diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a5eadadf8604..71a8aea59859 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,3 +5,9 @@ # Flower Baselines /baselines @jafermarq @tanertopal @danieljanes + +# Flower Examples +/examples @jafermarq @tanertopal @danieljanes + +# Changelog +/doc/source/ref-changelog.md @jafermarq @tanertopal @danieljanes diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8d73ed618919..479f88c1bbd5 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -37,10 +37,29 @@ Example: The variable `rnd` was renamed to `server_round` to improve readability - [ ] Implement proposed change - [ ] Write tests - [ ] Update [documentation](https://flower.dev/docs/writing-documentation.html) -- [ ] Update [changelog](https://github.com/adap/flower/blob/main/doc/source/changelog.rst) +- [ ] Update the changelog entry below - [ ] Make CI checks pass - [ ] Ping maintainers on [Slack](https://flower.dev/join-slack/) (channel `#contributions`) + + +### Changelog entry + + + ### Any other comments? 
class FlowerNumPyClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates one local model.

    Parameters
    ----------
    net : torch.nn.Module
        Model owned by this client.
    dataloader : Dict
        Mapping that must provide the keys ``"trainloader"``,
        ``"label_split"`` and ``"valloader"``.
    model_rate : Optional[float]
        Model rate assigned to this client; it is only stored here.
    client_train_settings : Dict
        Settings forwarded to ``train``. The dict is kept by reference and a
        ``"device"`` entry is written into it on construction.
    """

    def __init__(
        self,
        net: torch.nn.Module,
        dataloader,
        model_rate: Optional[float],
        client_train_settings: Dict,
    ):
        self.net = net
        self.trainloader = dataloader["trainloader"]
        self.label_split = dataloader["label_split"]
        self.valloader = dataloader["valloader"]
        self.model_rate = model_rate
        self.client_train_settings = client_train_settings
        # Prefer the first CUDA device and fall back to the CPU.
        device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.client_train_settings["device"] = torch.device(device_name)

    def get_parameters(self, config) -> NDArrays:
        """Return the current weights of the local net."""
        return get_parameters(self.net)

    def fit(self, parameters, config) -> Tuple[NDArrays, int, Dict]:
        """Load the given weights, train locally and return the new weights.

        If ``config`` carries an ``"lr"`` entry it overrides the learning rate
        stored in the training settings. The second element of the result is
        ``len(self.trainloader)``; the metrics dict is empty.
        """
        set_parameters(self.net, parameters)
        settings = self.client_train_settings
        if "lr" in config:
            settings["lr"] = config["lr"]
        train(self.net, self.trainloader, self.label_split, settings)
        updated_weights = get_parameters(self.net)
        return updated_weights, len(self.trainloader), {}

    def evaluate(self, parameters, config) -> Tuple[float, int, Dict]:
        """Load the given weights and evaluate them on the validation loader.

        Returns the loss, ``len(self.valloader)`` and a metrics dict holding
        the accuracy under the key ``"accuracy"``.
        """
        set_parameters(self.net, parameters)
        device = self.client_train_settings["device"]
        loss, accuracy = test(self.net, self.valloader, device=device)
        loss_value = float(loss)
        num_val = len(self.valloader)
        return loss_value, num_val, {"accuracy": float(accuracy)}
class ClientManagerHeteroFL(fl.server.ClientManager):
    """Provide a pool of available clients and their model-rate mapping.

    Parameters
    ----------
    model_rate_manager : optional
        Object exposing ``create_model_rate_mapping(num_clients)`` and a
        ``model_split_mode`` attribute.
    clients_to_model_rate_mapping : optional
        Mutable sequence indexed by ``int(cid)``. In simulation it is filled
        in place so that the same object can also be read by ``client_fn``.
    client_label_split : Optional[List[torch.Tensor]]
        Per-client label split; only stored here.

    Notes
    -----
    The manager runs in "simulation" mode only when both
    ``model_rate_manager`` and ``clients_to_model_rate_mapping`` are given.
    The non-simulation case is not handled yet.
    """

    def __init__(
        self,
        model_rate_manager=None,
        clients_to_model_rate_mapping=None,
        client_label_split: Optional[List[torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        self.clients: Dict[str, ClientProxy] = {}

        self.is_simulation = (
            model_rate_manager is not None
            and clients_to_model_rate_mapping is not None
        )
        self.model_rate_manager = model_rate_manager

        # Always define the attribute so the getters below never hit a
        # missing attribute, even outside of simulation.
        self.clients_to_model_rate_mapping = clients_to_model_rate_mapping
        if self.is_simulation:
            # Fill the caller's sequence in place: it is the array shared
            # between the server side and client_fn.
            self._refresh_model_rate_mapping(len(clients_to_model_rate_mapping))

        # The non-simulation case still has to be handled.
        self.client_label_split = client_label_split

        self._cv = threading.Condition()

    def _refresh_model_rate_mapping(self, num_clients: int) -> None:
        """Ask the model rate manager for new rates and copy them in place."""
        new_rates = self.model_rate_manager.create_model_rate_mapping(num_clients)
        for i, model_rate in enumerate(new_rates):
            self.clients_to_model_rate_mapping[i] = model_rate

    def __len__(self) -> int:
        """Return the length of clients Dict.

        Returns
        -------
        len : int
            Length of Dict (self.clients).
        """
        return len(self.clients)

    def num_available(self) -> int:
        """Return the number of available clients.

        Returns
        -------
        num_available : int
            The number of currently available clients.
        """
        return len(self)

    def wait_for(self, num_clients: int, timeout: int = 86400) -> bool:
        """Wait until at least `num_clients` are available.

        Blocks until the requested number of clients is available or until a
        timeout is reached. Current timeout default: 1 day.

        Parameters
        ----------
        num_clients : int
            The number of clients to wait for.
        timeout : int
            The time in seconds to wait for, defaults to 86400 (24h).

        Returns
        -------
        success : bool
        """
        with self._cv:
            return self._cv.wait_for(
                lambda: len(self.clients) >= num_clients, timeout=timeout
            )

    def register(self, client: ClientProxy) -> bool:
        """Register Flower ClientProxy instance.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy

        Returns
        -------
        success : bool
            Indicating if registration was successful. False if ClientProxy is
            already registered or can not be registered for any reason.
        """
        if client.cid in self.clients:
            return False

        self.clients[client.cid] = client

        # TODO: outside of simulation, the model rate and label split of the
        # client could be fetched here through its properties.

        # Wake up every thread blocked in wait_for.
        with self._cv:
            self._cv.notify_all()

        return True

    def unregister(self, client: ClientProxy) -> None:
        """Unregister Flower ClientProxy instance.

        This method is idempotent.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy
        """
        if client.cid in self.clients:
            del self.clients[client.cid]

            with self._cv:
                self._cv.notify_all()

    def all(self) -> Dict[str, ClientProxy]:
        """Return all available clients."""
        return self.clients

    def get_client_to_model_mapping(self, cid) -> float:
        """Return model rate of client with cid."""
        return self.clients_to_model_rate_mapping[int(cid)]

    def get_all_clients_to_model_mapping(self) -> List[float]:
        """Return a copy of the clients to model rate mapping."""
        return self.clients_to_model_rate_mapping.copy()

    def update(self, server_round: int) -> None:
        """Update the client to model rate mapping.

        In simulation the mapping is regenerated every round when the model
        split mode is ``"dynamic"`` and only in round 1 when it is ``"fix"``.
        Outside of simulation nothing happens yet.
        """
        if not self.is_simulation:
            # TODO: outside of simulation the model rates would have to be
            # queried from the clients again, as they may change.
            return

        mode = self.model_rate_manager.model_split_mode
        if mode == "dynamic" or (mode == "fix" and server_round == 1):
            self._refresh_model_rate_mapping(self.num_available())
            print(
                "clients to model rate mapping ", self.clients_to_model_rate_mapping
            )

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample a number of Flower ClientProxy instances.

        Returns an empty list when fewer than ``num_clients`` clients satisfy
        the criterion.
        """
        # Block until at least min_num_clients are connected.
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        # Keep only the clients which meet the criterion.
        available_cids = list(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]

        if num_clients > len(available_cids):
            log(
                INFO,
                "Sampling failed: number of available clients"
                " (%s) is less than number of requested clients (%s).",
                len(available_cids),
                num_clients,
            )
            return []

        # A single draw decides the selection; the earlier, immediately
        # overwritten torch-based draw was dead code and has been removed.
        sampled_cids = random.sample(available_cids, num_clients)
        print(f"Sampled CIDS = {sampled_cids}")
        return [self.clients[cid] for cid in sampled_cids]
def load_datasets(  # pylint: disable=too-many-arguments
    strategy_name: str,
    config: DictConfig,
    num_clients: int,
    seed: Optional[int] = 42,
) -> Tuple[
    DataLoader, List[DataLoader], List[torch.tensor], List[DataLoader], DataLoader
]:
    """Partition the dataset and wrap every part in a DataLoader.

    Parameters
    ----------
    strategy_name : str
        Name of the strategy, forwarded to the partitioning routine.
    config : DictConfig
        Dataset configuration. The fields read here are ``dataset_name``,
        ``iid``, ``shard_per_user``, ``balance``, ``batch_size.train``,
        ``batch_size.test``, ``shuffle.train`` and ``shuffle.test``.
    num_clients : int
        The number of clients that hold a part of the data.
    seed : int, optional
        Seed forwarded to the partitioning routine, by default 42.

    Returns
    -------
    Tuple
        In this order: the loader over the entire training set, one training
        loader per client, the label split returned by the partitioning
        routine, one validation loader per client, and the loader over the
        entire test set.
    """

    def _train_loader(data) -> DataLoader:
        # Loader configured with the "train" batch size and shuffle flag.
        return DataLoader(
            data, batch_size=config.batch_size.train, shuffle=config.shuffle.train
        )

    def _test_loader(data) -> DataLoader:
        # Loader configured with the "test" batch size and shuffle flag.
        return DataLoader(
            data, batch_size=config.batch_size.test, shuffle=config.shuffle.test
        )

    print(f"Dataset partitioning config: {config}")
    trainset, datasets, label_split, client_testsets, testset = _partition_data(
        num_clients,
        dataset_name=config.dataset_name,
        strategy_name=strategy_name,
        iid=config.iid,
        dataset_division={
            "shard_per_user": config.shard_per_user,
            "balance": config.balance,
        },
        seed=seed,
    )

    entire_trainloader = _train_loader(trainset)
    trainloaders = [_train_loader(client_part) for client_part in datasets]
    valloaders = [_test_loader(client_part) for client_part in client_testsets]
    entire_testloader = _test_loader(testset)

    return (
        entire_trainloader,
        trainloaders,
        label_split,
        valloaders,
        entire_testloader,
    )
def _download_data(dataset_name: str, strategy_name: str) -> Tuple[Dataset, Dataset]:
    """Build the train and test datasets for ``dataset_name``.

    Parameters
    ----------
    dataset_name : str
        Either ``"MNIST"`` or ``"CIFAR10"``.
    strategy_name : str
        For CIFAR10, ``"heterofl"`` selects a different set of normalization
        constants than any other strategy name. Unused for MNIST.

    Returns
    -------
    Tuple[Dataset, Dataset]
        The training set and the test set, rooted at ``./data/<dataset_name>``.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not one of the supported names.
    """
    root = f"./data/{dataset_name}"

    if dataset_name == "MNIST":

        def _mnist_transform():
            # A fresh pipeline per split, as each dataset gets its own Compose.
            return dt.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )

        trainset = dt.MNIST(
            root=root, split="train", subset="label", transform=_mnist_transform()
        )
        testset = dt.MNIST(
            root=root, split="test", subset="label", transform=_mnist_transform()
        )
        return trainset, testset

    if dataset_name == "CIFAR10":
        if strategy_name == "heterofl":
            normalize = transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            )
        else:
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )

        # Augmentation is applied to the training split only.
        train_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
        trainset = dt.CIFAR10(
            root=root,
            split="train",
            subset="label",
            transform=dt.Compose(train_steps),
        )
        test_steps = [transforms.ToTensor(), normalize]
        testset = dt.CIFAR10(
            root=root,
            split="test",
            subset="label",
            transform=dt.Compose(test_steps),
        )
        return trainset, testset

    raise ValueError(f"{dataset_name} is not valid")
def iid_partition(
    dataset: Dataset, num_clients: int, seed: Optional[int] = 42
) -> Tuple[List[Dataset], List[torch.Tensor]]:
    """Partition ``dataset`` randomly and (almost) evenly among clients.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose items are ``(input, target)`` pairs.
    num_clients : int
        Number of partitions to create.
    seed : int, optional
        Seed of the random split, by default 42. ``None`` requests a
        non-deterministic split.

    Returns
    -------
    Tuple[List[Dataset], List[torch.Tensor]]
        The partitions and, for every partition, the sorted unique targets it
        contains.

    Notes
    -----
    When ``len(dataset)`` is not divisible by ``num_clients`` the first
    ``len(dataset) % num_clients`` partitions receive one extra sample, so
    that no sample is dropped and ``random_split`` gets lengths that sum up to
    the dataset size. For divisible sizes the split is unchanged.
    """
    base_size, remainder = divmod(len(dataset), num_clients)
    lengths = [
        base_size + 1 if i < remainder else base_size for i in range(num_clients)
    ]

    generator = torch.Generator()
    if seed is None:
        # manual_seed rejects None; draw a non-deterministic seed instead.
        generator.seed()
    else:
        generator.manual_seed(seed)

    divided_dataset = random_split(dataset, lengths, generator)
    label_split = [
        torch.unique(torch.Tensor([target for _, target in partition]))
        for partition in divided_dataset
    ]

    return divided_dataset, label_split
+ """ + data_split: Dict[int, List] = {i: [] for i in range(num_clients)} + + label_idx_split, shard_per_class = _split_dataset_targets_idx( + dataset_info["dataset"], + shard_per_user, + num_clients, + dataset_info["classes_size"], + ) + + if label_split is None: + label_split = list(range(dataset_info["classes_size"])) * shard_per_class + label_split = torch.tensor(label_split)[ + torch.randperm( + len(label_split), generator=torch.Generator().manual_seed(seed) + ) + ].tolist() + label_split = np.array(label_split).reshape((num_clients, -1)).tolist() + + for i, _ in enumerate(label_split): + label_split[i] = np.unique(label_split[i]).tolist() + + for i in range(num_clients): + for label_i in label_split[i]: + idx = torch.arange(len(label_idx_split[label_i]))[ + torch.randperm( + len(label_idx_split[label_i]), + generator=torch.Generator().manual_seed(seed), + )[0] + ].item() + data_split[i].extend(label_idx_split[label_i].pop(idx)) + + return ( + _get_dataset_from_idx(dataset_info["dataset"], data_split, num_clients), + label_split, + ) + + +def _split_dataset_targets_idx(dataset, shard_per_user, num_clients, classes_size): + label = np.array(dataset.target) if hasattr(dataset, "target") else dataset.targets + label_idx_split: Dict = {} + for i, _ in enumerate(label): + label_i = label[i].item() + if label_i not in label_idx_split: + label_idx_split[label_i] = [] + label_idx_split[label_i].append(i) + + shard_per_class = int(shard_per_user * num_clients / classes_size) + + for label_i in label_idx_split: + label_idx = label_idx_split[label_i] + num_leftover = len(label_idx) % shard_per_class + leftover = label_idx[-num_leftover:] if num_leftover > 0 else [] + new_label_idx = ( + np.array(label_idx[:-num_leftover]) + if num_leftover > 0 + else np.array(label_idx) + ) + new_label_idx = new_label_idx.reshape((shard_per_class, -1)).tolist() + + for i, leftover_label_idx in enumerate(leftover): + new_label_idx[i] = np.concatenate([new_label_idx[i], 
[leftover_label_idx]]) + label_idx_split[label_i] = new_label_idx + return label_idx_split, shard_per_class + + +def _get_dataset_from_idx(dataset, data_split, num_clients): + divided_dataset = [None for i in range(num_clients)] + for i in range(num_clients): + divided_dataset[i] = Subset(dataset, data_split[i]) + return divided_dataset + + +def _balance_classes( + trainset: Dataset, + seed: Optional[int] = 42, +) -> Dataset: + class_counts = np.bincount(trainset.target) + targets = torch.Tensor(trainset.target) + smallest = np.min(class_counts) + idxs = targets.argsort() + tmp = [Subset(trainset, idxs[: int(smallest)])] + tmp_targets = [targets[idxs[: int(smallest)]]] + for count in np.cumsum(class_counts): + tmp.append(Subset(trainset, idxs[int(count) : int(count + smallest)])) + tmp_targets.append(targets[idxs[int(count) : int(count + smallest)]]) + unshuffled = ConcatDataset(tmp) + unshuffled_targets = torch.cat(tmp_targets) + shuffled_idxs = torch.randperm( + len(unshuffled), generator=torch.Generator().manual_seed(seed) + ) + shuffled = Subset(unshuffled, shuffled_idxs) + shuffled.targets = unshuffled_targets[shuffled_idxs] + + return shuffled + + +def _sort_by_class( + trainset: Dataset, +) -> Dataset: + class_counts = np.bincount(trainset.targets) + idxs = trainset.targets.argsort() # sort targets in ascending order + + tmp = [] # create subset of smallest class + tmp_targets = [] # same for targets + + start = 0 + for count in np.cumsum(class_counts): + tmp.append( + Subset(trainset, idxs[start : int(count + start)]) + ) # add rest of classes + tmp_targets.append(trainset.targets[idxs[start : int(count + start)]]) + start += count + sorted_dataset = ConcatDataset(tmp) # concat dataset + sorted_dataset.targets = torch.cat(tmp_targets) # concat targets + return sorted_dataset + + +# pylint: disable=too-many-locals, too-many-arguments +def _power_law_split( + sorted_trainset: Dataset, + num_partitions: int, + num_labels_per_partition: int = 2, + 
min_data_per_partition: int = 10, + mean: float = 0.0, + sigma: float = 2.0, +) -> Dataset: + """Partition the dataset following a power-law distribution. It follows the. + + implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with default + values set accordingly. + + Parameters + ---------- + sorted_trainset : Dataset + The training dataset sorted by label/class. + num_partitions: int + Number of partitions to create + num_labels_per_partition: int + Number of labels to have in each dataset partition. For + example if set to two, this means all training examples in + a given partition will be long to the same two classes. default 2 + min_data_per_partition: int + Minimum number of datapoints included in each partition, default 10 + mean: float + Mean value for LogNormal distribution to construct power-law, default 0.0 + sigma: float + Sigma value for LogNormal distribution to construct power-law, default 2.0 + + Returns + ------- + Dataset + The partitioned training dataset. 
+ """ + targets = sorted_trainset.targets + full_idx = list(range(len(targets))) + + class_counts = np.bincount(sorted_trainset.targets) + labels_cs = np.cumsum(class_counts) + labels_cs = [0] + labels_cs[:-1].tolist() + + partitions_idx: List[List[int]] = [] + num_classes = len(np.bincount(targets)) + hist = np.zeros(num_classes, dtype=np.int32) + + # assign min_data_per_partition + min_data_per_class = int(min_data_per_partition / num_labels_per_partition) + for u_id in range(num_partitions): + partitions_idx.append([]) + for cls_idx in range(num_labels_per_partition): + # label for the u_id-th client + cls = (u_id + cls_idx) % num_classes + # record minimum data + indices = list( + full_idx[ + labels_cs[cls] + + hist[cls] : labels_cs[cls] + + hist[cls] + + min_data_per_class + ] + ) + partitions_idx[-1].extend(indices) + hist[cls] += min_data_per_class + + # add remaining images following power-law + probs = np.random.lognormal( + mean, + sigma, + (num_classes, int(num_partitions / num_classes), num_labels_per_partition), + ) + remaining_per_class = class_counts - hist + # obtain how many samples each partition should be assigned for each of the + # labels it contains + # pylint: disable=too-many-function-args + probs = ( + remaining_per_class.reshape(-1, 1, 1) + * probs + / np.sum(probs, (1, 2), keepdims=True) + ) + + for u_id in range(num_partitions): + for cls_idx in range(num_labels_per_partition): + cls = (u_id + cls_idx) % num_classes + count = int(probs[cls, u_id // num_classes, cls_idx]) + + # add count of specific class to partition + indices = full_idx[ + labels_cs[cls] + hist[cls] : labels_cs[cls] + hist[cls] + count + ] + partitions_idx[u_id].extend(indices) + hist[cls] += count + + # construct subsets + partitions = [Subset(sorted_trainset, p) for p in partitions_idx] + return partitions diff --git a/baselines/heterofl/heterofl/datasets/__init__.py b/baselines/heterofl/heterofl/datasets/__init__.py new file mode 100644 index 
# pylint: disable=too-many-instance-attributes
class CIFAR10(Dataset):
    """CIFAR10 dataset.

    Parameters
    ----------
    root : str
        Root directory of the dataset; ``~`` is expanded. Raw files live in
        ``<root>/raw`` and processed files in ``<root>/processed``.
    split : str
        Name of the processed file to load (``"train"`` or ``"test"``).
    subset : str
        Key selecting the target and class metadata (``"label"`` is the only
        key produced by ``make_data``).
    transform : callable, optional
        Applied to the ``{"img": ..., subset: ...}`` dict of every item.
    """

    data_name = "CIFAR10"
    file = [
        (
            "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
            "c58f30108f718f92721af3b95e74349a",
        )
    ]

    def __init__(self, root, split, subset, transform=None):
        self.root = os.path.expanduser(root)
        self.split = split
        self.subset = subset
        self.transform = transform
        # Build the processed files on first use.
        if not check_exists(self.processed_folder):
            self.process()
        self.img, self.target = load(
            os.path.join(self.processed_folder, f"{self.split}.pt")
        )
        self.target = self.target[self.subset]
        self.classes_counts = make_classes_counts(self.target)
        classes_to_labels, classes_size = load(
            os.path.join(self.processed_folder, "meta.pt")
        )
        self.classes_to_labels = classes_to_labels[self.subset]
        self.classes_size = classes_size[self.subset]

    def __getitem__(self, index):
        """Return the (transformed) image and target stored at ``index``."""
        img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index])
        inp = {"img": img, self.subset: target}
        if self.transform is not None:
            inp = self.transform(inp)
        # Read the target back under the key it was stored with instead of a
        # hard-coded "label".
        return inp["img"], inp[self.subset]

    def __len__(self):
        """Length of the dataset."""
        return len(self.img)

    @property
    def processed_folder(self):
        """Return path of processed folder."""
        return os.path.join(self.root, "processed")

    @property
    def raw_folder(self):
        """Return path of raw folder."""
        return os.path.join(self.root, "raw")

    def process(self):
        """Create the processed train, test and meta files.

        The raw archive is downloaded first when the raw folder is missing.
        """
        if not check_exists(self.raw_folder):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, "train.pt"))
        save(test_set, os.path.join(self.processed_folder, "test.pt"))
        save(meta, os.path.join(self.processed_folder, "meta.pt"))

    def download(self):
        """Download every archive listed in ``file`` and extract it."""
        makedir_exist_ok(self.raw_folder)
        for url, md5 in self.file:
            filename = os.path.basename(url)
            download_url(url, self.raw_folder, filename, md5)
            extract_file(os.path.join(self.raw_folder, filename))

    def __repr__(self):
        """Represent CIFAR10 as string."""
        return (
            f"Dataset {self.__class__.__name__}\nSize: {len(self)}\n"
            f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n"
            f"Transforms: {self.transform!r}"
        )

    def make_data(self):
        """Read the raw batches and build the train, test and meta tuples.

        Returns
        -------
        Tuple
            ``(train_img, train_target)``, ``(test_img, test_target)`` and
            ``(classes_to_labels, classes_size)``; the targets and both meta
            entries are dicts keyed by ``"label"``.
        """
        batches_folder = os.path.join(self.raw_folder, "cifar-10-batches-py")
        train_filenames = [f"data_batch_{i}" for i in range(1, 6)]
        test_filenames = ["test_batch"]
        train_img, train_label = _read_pickle_file(batches_folder, train_filenames)
        test_img, test_label = _read_pickle_file(batches_folder, test_filenames)
        train_target, test_target = {"label": train_label}, {"label": test_label}
        # NOTE(review): pickle must only be fed trusted data; this file comes
        # from the archive fetched in download().
        with open(os.path.join(batches_folder, "batches.meta"), "rb") as fle:
            data = pickle.load(fle, encoding="latin1")
            classes = data["label_names"]
        classes_to_labels = {"label": anytree.Node("U", index=[])}
        for cls in classes:
            make_tree(classes_to_labels["label"], [cls])
        classes_size = {"label": make_flat_index(classes_to_labels["label"])}
        return (
            (train_img, train_target),
            (test_img, test_target),
            (classes_to_labels, classes_size),
        )


def _read_pickle_file(path, filenames):
    """Load images and labels from the pickled batch files in ``path``.

    Returns the images stacked into one array that is reshaped to
    ``(-1, 3, 32, 32)`` and then transposed to channel-last layout, plus the
    flat list of labels (taken from ``"labels"`` or, failing that, from
    ``"fine_labels"``).
    """
    img, label = [], []
    for filename in filenames:
        file_path = os.path.join(path, filename)
        # NOTE(review): pickle must only be fed trusted data.
        with open(file_path, "rb") as file:
            entry = pickle.load(file, encoding="latin1")
            img.append(entry["data"])
            if "labels" in entry:
                label.extend(entry["labels"])
            else:
                label.extend(entry["fine_labels"])
    img = np.vstack(img).reshape(-1, 3, 32, 32)
    img = img.transpose((0, 2, 3, 1))
    return img, label
"9fb629c4189551a2d022fa330f9573f3", + ), + ( + "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", + "d53e105ee54ea40749a09fcbcd1e9432", + ), + ( + "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", + "ec29112dd5afa0611ce80d1b7f02629c", + ), + ] + + def __init__(self, root, split, subset, transform=None): + self.root = os.path.expanduser(root) + self.split = split + self.subset = subset + self.transform = transform + if not check_exists(self.processed_folder): + self.process() + self.img, self.target = load( + os.path.join(self.processed_folder, "{}.pt".format(self.split)) + ) + self.target = self.target[self.subset] + self.classes_counts = make_classes_counts(self.target) + self.classes_to_labels, self.classes_size = load( + os.path.join(self.processed_folder, "meta.pt") + ) + self.classes_to_labels, self.classes_size = ( + self.classes_to_labels[self.subset], + self.classes_size[self.subset], + ) + + def __getitem__(self, index): + """Get the item with index.""" + img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) + inp = {"img": img, self.subset: target} + if self.transform is not None: + inp = self.transform(inp) + return inp["img"], inp["label"] + + def __len__(self): + """Length of the dataset.""" + return len(self.img) + + @property + def processed_folder(self): + """Return path of processed folder.""" + return os.path.join(self.root, "processed") + + @property + def raw_folder(self): + """Return path of raw folder.""" + return os.path.join(self.root, "raw") + + def process(self): + """Save the dataset accordingly.""" + if not check_exists(self.raw_folder): + self.download() + train_set, test_set, meta = self.make_data() + save(train_set, os.path.join(self.processed_folder, "train.pt")) + save(test_set, os.path.join(self.processed_folder, "test.pt")) + save(meta, os.path.join(self.processed_folder, "meta.pt")) + + def download(self): + """Download and save the dataset accordingly.""" + 
makedir_exist_ok(self.raw_folder) + for url, md5 in self.file: + filename = os.path.basename(url) + download_url(url, self.raw_folder, filename, md5) + extract_file(os.path.join(self.raw_folder, filename)) + + def __repr__(self): + """Represent CIFAR10 as string.""" + fmt_str = ( + f"Dataset {self.__class__.__name__}\nSize: {self.__len__()}\n" + f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n" + f"Transforms: {self.transform.__repr__()}" + ) + return fmt_str + + def make_data(self): + """Make data.""" + train_img = _read_image_file( + os.path.join(self.raw_folder, "train-images-idx3-ubyte") + ) + test_img = _read_image_file( + os.path.join(self.raw_folder, "t10k-images-idx3-ubyte") + ) + train_label = _read_label_file( + os.path.join(self.raw_folder, "train-labels-idx1-ubyte") + ) + test_label = _read_label_file( + os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte") + ) + train_target, test_target = {"label": train_label}, {"label": test_label} + classes_to_labels = {"label": anytree.Node("U", index=[])} + classes = list(map(str, list(range(10)))) + for cls in classes: + make_tree(classes_to_labels["label"], [cls]) + classes_size = {"label": make_flat_index(classes_to_labels["label"])} + return ( + (train_img, train_target), + (test_img, test_target), + (classes_to_labels, classes_size), + ) + + +def _get_int(num): + return int(codecs.encode(num, "hex"), 16) + + +def _read_image_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2051 + length = _get_int(data[4:8]) + num_rows = _get_int(data[8:12]) + num_cols = _get_int(data[12:16]) + parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape( + (length, num_rows, num_cols) + ) + return parsed + + +def _read_label_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2049 + length = _get_int(data[4:8]) + parsed = ( + np.frombuffer(data, dtype=np.uint8, offset=8) + .reshape(length) + .astype(np.int64) + 
) + return parsed diff --git a/baselines/heterofl/heterofl/datasets/utils.py b/baselines/heterofl/heterofl/datasets/utils.py new file mode 100644 index 000000000000..6b71811ed50d --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/utils.py @@ -0,0 +1,244 @@ +"""Contains utility functions required for datasests. + +Adopted from authors implementation. +""" +import glob +import gzip +import hashlib +import os +import tarfile +import zipfile +from collections import Counter + +import anytree +import numpy as np +from PIL import Image +from six.moves import urllib +from tqdm import tqdm + +from heterofl.utils import makedir_exist_ok + +IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] + + +def find_classes(drctry): + """Find the classes in a directory.""" + classes = [d.name for d in os.scandir(drctry) if d.is_dir()] + classes.sort() + classes_to_labels = {classes[i]: i for i in range(len(classes))} + return classes_to_labels + + +def pil_loader(path): + """Load image from path using PIL.""" + with open(path, "rb") as file: + img = Image.open(file) + return img.convert("RGB") + + +# def accimage_loader(path): +# """Load image from path using accimage_loader.""" +# import accimage + +# try: +# return accimage.Image(path) +# except IOError: +# return pil_loader(path) + + +def default_loader(path): + """Load image from path using default loader.""" + # if get_image_backend() == "accimage": + # return accimage_loader(path) + + return pil_loader(path) + + +def has_file_allowed_extension(filename, extensions): + """Check whether file possesses any of the extensions listed.""" + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def make_classes_counts(label): + """Count number of classes.""" + label = np.array(label) + if label.ndim > 1: + label = label.sum(axis=tuple(range(1, label.ndim))) + classes_counts = Counter(label) + return classes_counts + + +def _make_bar_updater(pbar): + def 
bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def _calculate_md5(path, chunk_size=1024 * 1024): + md5 = hashlib.md5() + with open(path, "rb") as file: + for chunk in iter(lambda: file.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def _check_md5(path, md5, **kwargs): + return md5 == _calculate_md5(path, **kwargs) + + +def _check_integrity(path, md5=None): + if not os.path.isfile(path): + return False + if md5 is None: + return True + return _check_md5(path, md5) + + +def download_url(url, root, filename, md5): + """Download files from the url.""" + path = os.path.join(root, filename) + makedir_exist_ok(root) + if os.path.isfile(path) and _check_integrity(path, md5): + print("Using downloaded and verified file: " + path) + else: + try: + print("Downloading " + url + " to " + path) + urllib.request.urlretrieve( + url, path, reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)) + ) + except OSError: + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." 
+ " Downloading " + url + " to " + path + ) + urllib.request.urlretrieve( + url, + path, + reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)), + ) + if not _check_integrity(path, md5): + raise RuntimeError("Not valid downloaded file") + + +def extract_file(src, dest=None, delete=False): + """Extract the file.""" + print("Extracting {}".format(src)) + dest = os.path.dirname(src) if dest is None else dest + filename = os.path.basename(src) + if filename.endswith(".zip"): + with zipfile.ZipFile(src, "r") as zip_f: + zip_f.extractall(dest) + elif filename.endswith(".tar"): + with tarfile.open(src) as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".tar.gz") or filename.endswith(".tgz"): + with tarfile.open(src, "r:gz") as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".gz"): + with open(src.replace(".gz", ""), "wb") as out_f, gzip.GzipFile(src) as zip_f: + out_f.write(zip_f.read()) + if delete: + os.remove(src) + + +def make_data(root, extensions): + """Get all the files in the root directory that follows the given extensions.""" + path = [] + files = glob.glob("{}/**/*".format(root), recursive=True) + for file in files: + if has_file_allowed_extension(file, extensions): + path.append(os.path.normpath(file)) + return path + + +# pylint: disable=dangerous-default-value +def make_img(path, classes_to_labels, extensions=IMG_EXTENSIONS): + """Make image.""" + img, label = [], [] + classes = [] + leaf_nodes = classes_to_labels.leaves + for node in leaf_nodes: + classes.append(node.name) + for cls in sorted(classes): + folder = os.path.join(path, cls) + if not os.path.isdir(folder): + continue + for root, _, filenames in sorted(os.walk(folder)): + for filename in sorted(filenames): + if has_file_allowed_extension(filename, extensions): + cur_path = os.path.join(root, filename) + img.append(cur_path) + label.append( + anytree.find_by_attr(classes_to_labels, cls).flat_index + ) + return img, label + + +def make_tree(root, name, attribute=None): 
+ """Create a tree of name.""" + if len(name) == 0: + return + if attribute is None: + attribute = {} + this_name = name[0] + next_name = name[1:] + this_attribute = {k: attribute[k][0] for k in attribute} + next_attribute = {k: attribute[k][1:] for k in attribute} + this_node = anytree.find_by_attr(root, this_name) + this_index = root.index + [len(root.children)] + if this_node is None: + this_node = anytree.Node( + this_name, parent=root, index=this_index, **this_attribute + ) + make_tree(this_node, next_name, next_attribute) + return + + +def make_flat_index(root, given=None): + """Make flat index for each leaf node in the tree.""" + if given: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = given.index(node.name) + classes_size = ( + given.index(node.name) + 1 + if given.index(node.name) + 1 > classes_size + else classes_size + ) + else: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = classes_size + classes_size += 1 + return classes_size + + +class Compose: + """Custom Compose class.""" + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, inp): + """Apply transforms when called.""" + for transform in self.transforms: + inp["img"] = transform(inp["img"]) + return inp + + def __repr__(self): + """Represent Compose as string.""" + format_string = self.__class__.__name__ + "(" + for transform in self.transforms: + format_string += "\n" + format_string += " {0}".format(transform) + format_string += "\n)" + return format_string diff --git a/baselines/heterofl/heterofl/main.py b/baselines/heterofl/heterofl/main.py new file mode 100644 index 000000000000..3973841cb60e --- /dev/null +++ b/baselines/heterofl/heterofl/main.py @@ -0,0 +1,204 @@ +"""Runs federated learning for given configuration in base.yaml.""" +import pickle +from pathlib import Path + +import flwr as fl +import hydra +import torch +from 
hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from heterofl import client, models, server +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.dataset import load_datasets +from heterofl.model_properties import get_model_properties +from heterofl.utils import ModelRateManager, get_global_model_rate, preprocess_input + + +# pylint: disable=too-many-locals,protected-access +@hydra.main(config_path="conf", config_name="base.yaml", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # print config structured as YAML + print(OmegaConf.to_yaml(cfg)) + torch.manual_seed(cfg.seed) + + data_loaders = {} + + ( + data_loaders["entire_trainloader"], + data_loaders["trainloaders"], + data_loaders["label_split"], + data_loaders["valloaders"], + data_loaders["testloader"], + ) = load_datasets( + "heterofl" if "heterofl" in cfg.strategy._target_ else "fedavg", + config=cfg.dataset, + num_clients=cfg.num_clients, + seed=cfg.seed, + ) + + model_config = preprocess_input(cfg.model, cfg.dataset) + + model_split_rate = None + model_mode = None + client_to_model_rate_mapping = None + model_rate_manager = None + history = None + + if "HeteroFL" in cfg.strategy._target_: + # send this array(client_model_rate_mapping) as + # an argument to client_manager and client + model_split_rate = {"a": 1, "b": 0.5, "c": 0.25, "d": 0.125, "e": 0.0625} + # model_split_mode = cfg.control.model_split_mode + model_mode = cfg.control.model_mode + + client_to_model_rate_mapping = [float(0) for _ in range(cfg.num_clients)] + model_rate_manager = ModelRateManager( + cfg.control.model_split_mode, model_split_rate, model_mode + ) + + model_config["global_model_rate"] = model_split_rate[ + get_global_model_rate(model_mode) + ] + + test_model = models.create_model( + 
model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=True, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + + get_model_properties( + model_config, + model_split_rate, + model_mode + "" if model_mode is not None else None, + data_loaders["entire_trainloader"], + cfg.dataset.batch_size.train, + ) + + # prepare function that will be used to spawn each client + client_train_settings = { + "epochs": cfg.num_epochs, + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + if "clip" in cfg: + client_train_settings["clip"] = cfg.clip + + optim_scheduler_settings = { + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + client_fn = client.gen_client_fn( + model_config=model_config, + client_to_model_rate_mapping=client_to_model_rate_mapping, + client_train_settings=client_train_settings, + data_loaders=data_loaders, + ) + + evaluate_fn = server.gen_evaluate_fn( + data_loaders, + torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + test_model, + models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=False, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + .state_dict() + .keys(), + enable_train_on_train_data=cfg.enable_train_on_train_data_while_testing + if "enable_train_on_train_data_while_testing" in cfg + else True, + ) + client_resources = { + "num_cpus": cfg.client_resources.num_cpus, + 
"num_gpus": cfg.client_resources.num_gpus if torch.cuda.is_available() else 0, + } + + if "HeteroFL" in cfg.strategy._target_: + strategy_heterofl = instantiate( + cfg.strategy, + model_name=cfg.model.model_name, + net=models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + device="cpu", + ), + optim_scheduler_settings=optim_scheduler_settings, + global_model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else 1.0, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + client_manager=ClientManagerHeteroFL( + model_rate_manager, + client_to_model_rate_mapping, + client_label_split=data_loaders["label_split"], + ), + strategy=strategy_heterofl, + ) + else: + strategy_fedavg = instantiate( + cfg.strategy, + # on_fit_config_fn=lambda server_round: { + # "lr": cfg.optim_scheduler.lr + # * pow(cfg.optim_scheduler.lr_decay_rate, server_round) + # }, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + strategy=strategy_fedavg, + ) + + # save the results + save_path = HydraConfig.get().runtime.output_dir + + # save the results as a python pickle + with open(str(Path(save_path) / "results.pkl"), "wb") as file_handle: + pickle.dump({"history": history}, file_handle, protocol=pickle.HIGHEST_PROTOCOL) + + # save the model + torch.save(test_model.state_dict(), str(Path(save_path) / "model.pth")) + + +if __name__ == "__main__": + main() diff --git a/baselines/heterofl/heterofl/model_properties.py 
# baselines/heterofl/heterofl/model_properties.py
"""Determine number of model parameters, space it requires."""
import numpy as np
import torch
import torch.nn as nn

from heterofl.models import create_model


def _parse_model_mode(model_mode):
    """Split a model-mode string into ``(rate_key, weight)`` pairs.

    Entries are separated by ``-``. The first character of an entry is the key
    into ``model_split_rate``; everything after it is the integer weight, so
    multi-digit weights such as ``"a10"`` are read in full.
    """
    return [(entry[0], int(entry[1:])) for entry in model_mode.split("-")]


def get_model_properties(
    model_config, model_split_rate, model_mode, data_loader, batch_size
):
    """Calculate space occupied & number of parameters of model.

    Parameters
    ----------
    model_config
        Configuration passed to ``create_model``.
    model_split_rate
        Mapping from rate key (first character of a mode entry) to model rate.
        Only read when ``model_mode`` is not None.
    model_mode
        ``-``-separated entries of rate key followed by an integer weight, or
        None to measure the single model built from ``model_config`` alone.
    data_loader
        Loader whose first batch is pushed through each model to count FLOPs.
    batch_size
        Divisor applied to the FLOP total.

    Returns
    -------
    tuple
        ``(total_model_parameters, total_flops, space)``. With a model mode the
        first two are weighted averages over its entries. ``space`` is the
        parameter count at 32 bits each, expressed in units of 1024**2 bytes.
    """
    if model_mode is None:
        total_flops = _calculate_model_memory(create_model(model_config), data_loader)
        total_model_parameters = _count_parameters(create_model(model_config))
    else:
        total_flops = 0
        total_model_parameters = 0
        total_weight = 0
        # NOTE(review): a fresh model is built for each of the two measurements;
        # sharing one instance would change how many models get initialised, so
        # the original call pattern is kept as-is.
        for rate_key, weight in _parse_model_mode(model_mode):
            model_rate = model_split_rate[rate_key]
            total_flops += (
                _calculate_model_memory(
                    create_model(model_config, model_rate=model_rate), data_loader
                )
                * weight
            )
            total_model_parameters += (
                _count_parameters(create_model(model_config, model_rate=model_rate))
                * weight
            )
            total_weight += weight
        # Normalise the weighted sums; skipped when all weights are zero.
        if total_weight != 0:
            total_flops /= total_weight
            total_model_parameters /= total_weight

    total_flops /= batch_size

    space = total_model_parameters * 32.0 / 8 / (1024**2.0)
    print("num_of_parameters = ", total_model_parameters / 1000, " K")
    print("total_flops = ", total_flops / 1000000, " M")
    print("space = ", space)

    return total_model_parameters, total_flops, space


def _calculate_model_memory(model, data_loader):
    """Run one batch through ``model`` and return the FLOPs seen by its hooks.

    Despite the name, the value returned is a FLOP count. A forward hook is
    attached to every sub-module except ``nn.Sequential``, ``nn.ModuleList``,
    ``nn.ModuleDict`` and ``model`` itself.
    """
    hooks = []
    flops = []

    def hook(module, inp, output):
        # One entry is recorded per parameter tensor the module reports.
        # NOTE(review): a layer with both weight and bias is therefore counted
        # twice - confirm this is intended before relying on absolute values.
        for _ in module.named_parameters():
            flops.append(_make_flops(module, inp, output))

    def register_hook(module):
        is_container = isinstance(
            module, (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        )
        if not is_container and module is not model:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)
    try:
        one_dl = next(iter(data_loader))
        input_dict = {"img": one_dl[0], "label": one_dl[1]}
        with torch.no_grad():
            model(input_dict)
    finally:
        # Detach the hooks even if fetching the batch or the forward pass fails.
        for handle in hooks:
            handle.remove()

    return sum(flops)


def _count_parameters(model):
    """Return the number of elements across parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _make_flops(module, inp, output):
    """Unwrap tuple inputs/outputs to their first element, then count FLOPs."""
    while isinstance(inp, tuple):
        inp = inp[0]
    while isinstance(output, tuple):
        output = output[0]
    return _compute_flops(module, inp, output)


def _compute_flops(module, inp, out):
    """Return the FLOP estimate for one module call; unknown module types give 0."""
    if isinstance(module, nn.Conv2d):
        return _compute_conv2d_flops(module, inp, out)
    if isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        flops = np.prod(inp.shape).item()
        # The count is doubled when the norm layer is affine.
        return flops * 2 if module.affine else flops
    if isinstance(module, nn.Linear):
        return np.prod(inp.size()[:-1]).item() * inp.size()[-1] * out.size()[-1]
    return 0


def _compute_conv2d_flops(module, inp, out):
    """Return the FLOP estimate of a ``nn.Conv2d`` call, including its bias."""
    batch_size = inp.size()[0]
    in_c = inp.size()[1]
    out_c, out_h, out_w = out.size()[1:]
    filters_per_channel = out_c // module.groups
    conv_per_position_flops = (
        module.kernel_size[0] * module.kernel_size[1] * in_c * filters_per_channel
    )
    # Number of output positions over the whole batch.
    active_elements_count = batch_size * out_h * out_w
    total_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = out_c * active_elements_count if module.bias is not None else 0
    return total_conv_flops + bias_flops
a/baselines/heterofl/heterofl/models.py b/baselines/heterofl/heterofl/models.py new file mode 100644 index 000000000000..9426ee8b2789 --- /dev/null +++ b/baselines/heterofl/heterofl/models.py @@ -0,0 +1,839 @@ +"""Conv & resnet18 model architecture, training, testing functions. + +Classes Conv, Block, Resnet18 are adopted from authors implementation. +""" +import copy +from typing import List, OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from flwr.common import parameters_to_ndarrays +from torch import nn + +from heterofl.utils import make_optimizer + + +class Conv(nn.Module): + """Convolutional Neural Network architecture with sBN.""" + + def __init__( + self, + model_config, + ): + super().__init__() + self.model_config = model_config + + blocks = [ + nn.Conv2d( + model_config["data_shape"][0], model_config["hidden_size"][0], 3, 1, 1 + ), + self._get_scale(), + self._get_norm(0), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + for i in range(len(model_config["hidden_size"]) - 1): + blocks.extend( + [ + nn.Conv2d( + model_config["hidden_size"][i], + model_config["hidden_size"][i + 1], + 3, + 1, + 1, + ), + self._get_scale(), + self._get_norm(i + 1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + ) + blocks = blocks[:-1] + blocks.extend( + [ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear( + model_config["hidden_size"][-1], model_config["classes_size"] + ), + ] + ) + self.blocks = nn.Sequential(*blocks) + + def _get_norm(self, j: int): + """Return the relavant norm.""" + if self.model_config["norm"] == "bn": + norm = nn.BatchNorm2d( + self.model_config["hidden_size"][j], + momentum=None, + track_running_stats=self.model_config["track"], + ) + elif self.model_config["norm"] == "in": + norm = nn.GroupNorm( + self.model_config["hidden_size"][j], self.model_config["hidden_size"][j] + ) + elif self.model_config["norm"] == "ln": + norm = nn.GroupNorm(1, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] 
== "gn": + norm = nn.GroupNorm(4, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] == "none": + norm = nn.Identity() + else: + raise ValueError("Not valid norm") + + return norm + + def _get_scale(self): + """Return the relavant scaler.""" + if self.model_config["scale"]: + scaler = _Scaler(self.model_config["rate"]) + else: + scaler = nn.Identity() + return scaler + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + # output = {"loss": torch.tensor(0, device=self.device, dtype=torch.float32)} + output = {} + out = self.blocks(input_dict["img"]) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def conv( + model_rate, + model_config, + device="cpu", +): + """Create the Conv model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = Conv(model_config) + model.apply(_init_param) + return model.to(device) + + +class Block(nn.Module): + """Block.""" + + expansion = 1 + + def __init__(self, in_planes, planes, stride, model_config): + super().__init__() + if model_config["norm"] == "bn": + n_1 = nn.BatchNorm2d( + in_planes, momentum=None, track_running_stats=model_config["track"] + ) + n_2 = nn.BatchNorm2d( + planes, momentum=None, 
track_running_stats=model_config["track"] + ) + elif model_config["norm"] == "in": + n_1 = nn.GroupNorm(in_planes, in_planes) + n_2 = nn.GroupNorm(planes, planes) + elif model_config["norm"] == "ln": + n_1 = nn.GroupNorm(1, in_planes) + n_2 = nn.GroupNorm(1, planes) + elif model_config["norm"] == "gn": + n_1 = nn.GroupNorm(4, in_planes) + n_2 = nn.GroupNorm(4, planes) + elif model_config["norm"] == "none": + n_1 = nn.Identity() + n_2 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_1 = n_1 + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.n_2 = n_2 + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = nn.Identity() + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ) + + def forward(self, x): + """Forward pass of the Block. + + Parameters + ---------- + x : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. 
+ """ + out = F.relu(self.n_1(self.scaler(x))) + shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x + out = self.conv1(out) + out = self.conv2(F.relu(self.n_2(self.scaler(out)))) + out += shortcut + return out + + +# pylint: disable=too-many-instance-attributes +class ResNet(nn.Module): + """Implementation of a Residual Neural Network (ResNet) model with sBN.""" + + def __init__( + self, + model_config, + block, + num_blocks, + ): + self.model_config = model_config + super().__init__() + self.in_planes = model_config["hidden_size"][0] + self.conv1 = nn.Conv2d( + model_config["data_shape"][0], + model_config["hidden_size"][0], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + self.layer1 = self._make_layer( + block, + model_config["hidden_size"][0], + num_blocks[0], + stride=1, + ) + self.layer2 = self._make_layer( + block, + model_config["hidden_size"][1], + num_blocks[1], + stride=2, + ) + self.layer3 = self._make_layer( + block, + model_config["hidden_size"][2], + num_blocks[2], + stride=2, + ) + self.layer4 = self._make_layer( + block, + model_config["hidden_size"][3], + num_blocks[3], + stride=2, + ) + + # self.layers = [layer1, layer2, layer3, layer4] + + if model_config["norm"] == "bn": + n_4 = nn.BatchNorm2d( + model_config["hidden_size"][3] * block.expansion, + momentum=None, + track_running_stats=model_config["track"], + ) + elif model_config["norm"] == "in": + n_4 = nn.GroupNorm( + model_config["hidden_size"][3] * block.expansion, + model_config["hidden_size"][3] * block.expansion, + ) + elif model_config["norm"] == "ln": + n_4 = nn.GroupNorm(1, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "gn": + n_4 = nn.GroupNorm(4, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "none": + n_4 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_4 = n_4 + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = 
nn.Identity() + self.linear = nn.Linear( + model_config["hidden_size"][3] * block.expansion, + model_config["classes_size"], + ) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for strd in strides: + layers.append(block(self.in_planes, planes, strd, self.model_config.copy())) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, input_dict): + """Forward pass of the ResNet. + + Parameters + ---------- + input_dict : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + out = self.conv1(x) + # for layer in self.layers: + # out = layer(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.relu(self.n_4(self.scaler(out))) + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.size(0), -1) + out = self.linear(out) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(output["score"], input_dict["label"]) + return output + + +def resnet18( + model_rate, + model_config, + device="cpu", +): + """Create the ResNet18 model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = ResNet(model_config, block=Block, num_blocks=[1, 1, 1, 2]) + model.apply(_init_param) + return model.to(device) + + +class MLP(nn.Module): + """Multi Layer 
Perceptron.""" + + def __init__(self): + super().__init__() + self.layer_input = nn.Linear(784, 512) + self.relu = nn.ReLU() + self.dropout = nn.Dropout() + self.layer_hidden1 = nn.Linear(512, 256) + self.layer_hidden2 = nn.Linear(256, 256) + self.layer_hidden3 = nn.Linear(256, 128) + self.layer_out = nn.Linear(128, 10) + self.softmax = nn.Softmax(dim=1) + self.weight_keys = [ + ["layer_input.weight", "layer_input.bias"], + ["layer_hidden1.weight", "layer_hidden1.bias"], + ["layer_hidden2.weight", "layer_hidden2.bias"], + ["layer_hidden3.weight", "layer_hidden3.bias"], + ["layer_out.weight", "layer_out.bias"], + ] + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) + x = self.layer_input(x) + x = self.relu(x) + + x = self.layer_hidden1(x) + x = self.relu(x) + + x = self.layer_hidden2(x) + x = self.relu(x) + + x = self.layer_hidden3(x) + x = self.relu(x) + + x = self.layer_out(x) + out = self.softmax(x) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +class CNNCifar(nn.Module): + """Convolutional Neural Network architecture for cifar dataset.""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 100) + self.fc3 = nn.Linear(100, 10) + + self.weight_keys = [ + ["fc1.weight", "fc1.bias"], + ["fc2.weight", "fc2.bias"], + ["fc3.weight", "fc3.bias"], + ["conv2.weight", "conv2.bias"], + ["conv1.weight", "conv1.bias"], + ] + + def 
forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + out = F.log_softmax(x, dim=1) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def create_model(model_config, model_rate=None, track=False, device="cpu"): + """Create the model based on the configuration given in hydra.""" + model = None + model_config = model_config.copy() + model_config["track"] = track + + if model_config["model"] == "MLP": + model = MLP() + model.to(device) + elif model_config["model"] == "CNNCifar": + model = CNNCifar() + model.to(device) + elif model_config["model"] == "conv": + model = conv(model_rate=model_rate, model_config=model_config, device=device) + elif model_config["model"] == "resnet18": + model = resnet18( + model_rate=model_rate, model_config=model_config, device=device + ) + return model + + +def _init_param(m_param): + if isinstance(m_param, (nn.BatchNorm2d, nn.InstanceNorm2d)): + m_param.weight.data.fill_(1) + m_param.bias.data.zero_() + elif isinstance(m_param, nn.Linear): + m_param.bias.data.zero_() + return m_param + + +class _Scaler(nn.Module): + def __init__(self, rate): + super().__init__() + self.rate = rate + + def forward(self, inp): + """Forward of Scalar nn.Module.""" + output = inp / self.rate if self.training else inp + return output + + +def get_parameters(net) -> List[np.ndarray]: + """Return the parameters of model as numpy.NDArrays.""" + return [val.cpu().numpy() for _, val 
in net.state_dict().items()] + + +def set_parameters(net, parameters: List[np.ndarray]): + """Set the model parameters with given parameters.""" + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + +def train(model, train_loader, label_split, settings): + """Train a model with given settings. + + Parameters + ---------- + model : nn.Module + The neural network to train. + train_loader : DataLoader + The DataLoader containing the data to train the network on. + label_split : torch.tensor + Tensor containing the labels of the data. + settings: Dict + Dictionary conatining the information about eopchs, optimizer, + lr, momentum, weight_decay, device to train on. + """ + # criterion = torch.nn.CrossEntropyLoss() + optimizer = make_optimizer( + settings["optimizer"], + model.parameters(), + learning_rate=settings["lr"], + momentum=settings["momentum"], + weight_decay=settings["weight_decay"], + ) + + model.train() + for _ in range(settings["epochs"]): + for images, labels in train_loader: + input_dict = {} + input_dict["img"] = images.to(settings["device"]) + input_dict["label"] = labels.to(settings["device"]) + input_dict["label_split"] = label_split.type(torch.int).to( + settings["device"] + ) + optimizer.zero_grad() + output = model(input_dict) + output["loss"].backward() + if ("clip" not in settings) or ( + "clip" in settings and settings["clip"] is True + ): + torch.nn.utils.clip_grad_norm_(model.parameters(), 1) + optimizer.step() + + +def test(model, test_loader, label_split=None, device="cpu"): + """Evaluate the network on the test set. + + Parameters + ---------- + model : nn.Module + The neural network to test. + test_loader : DataLoader + The DataLoader containing the data to test the network on. + device : torch.device + The device on which the model should be tested, either 'cpu' or 'cuda'. 
+ + Returns + ------- + Tuple[float, float] + The loss and the accuracy of the input model on the given data. + """ + model.eval() + size = len(test_loader.dataset) + num_batches = len(test_loader) + test_loss, correct = 0, 0 + + with torch.no_grad(): + model.train(False) + for images, labels in test_loader: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + if label_split is not None: + input_dict["label_split"] = label_split.type(torch.int).to(device) + output = model(input_dict) + test_loss += output["loss"].item() + correct += ( + (output["score"].argmax(1) == input_dict["label"]) + .type(torch.float) + .sum() + .item() + ) + + test_loss /= num_batches + correct /= size + return test_loss, correct + + +def param_model_rate_mapping( + model_name, parameters, clients_model_rate, global_model_rate=1 +): + """Map the model rate to subset of global parameters(as list of indices). + + Parameters + ---------- + model_name : str + The name of the neural network of global model. + parameters : Dict + state_dict of the global model. + client_model_rate : List[float] + List of model rates of active clients. + global_model_rate: float + Model rate of the global model. + + Returns + ------- + Dict + model rate to parameters indices relative to global model mapping. 
+ """ + unique_client_model_rate = list(set(clients_model_rate)) + print(unique_client_model_rate) + + if "conv" in model_name: + idx = _mr_to_param_idx_conv( + parameters, unique_client_model_rate, global_model_rate + ) + elif "resnet" in model_name: + idx = _mr_to_param_idx_resnet18( + parameters, unique_client_model_rate, global_model_rate + ) + else: + raise ValueError("Not valid model name") + + # add model rate as key to the params calculated + param_idx_model_rate_mapping = OrderedDict() + for i, _ in enumerate(unique_client_model_rate): + param_idx_model_rate_mapping[unique_client_model_rate[i]] = idx[i] + + return param_idx_model_rate_mapping + + +def _mr_to_param_idx_conv(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + output_weight_name = [k for k in parameters.keys() if "weight" in k][-1] + output_bias_name = [k for k in parameters.keys() if "bias" in k][-1] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_conv( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + output_names={ + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_conv( + idx, + idx_i, + param_info, + output_names, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m = 
idx_i[param_info["index"]] + if param_info["k"] == output_names["output_weight_name"]: + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx[param_info["index"]][param_info["k"]] = output_idx_i_m, input_idx_i_m + idx_i[param_info["index"]] = output_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + if param_info["k"] == output_names["output_bias_name"]: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def _mr_to_param_idx_resnet18(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_resnet18( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_resnet18( + idx, + idx_i, + param_info, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if "conv1" in param_info["k"] or "conv2" in param_info["k"]: + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m 
= idx_i[param_info["index"]] + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx_i[param_info["index"]] = output_idx_i_m + elif "shortcut" in param_info["k"]: + input_idx_i_m = idx[param_info["index"]][ + param_info["k"].replace("shortcut", "conv1") + ][1] + output_idx_i_m = idx_i[param_info["index"]] + elif "linear" in param_info["k"]: + input_idx_i_m = idx_i[param_info["index"]] + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + raise ValueError("Not valid k") + idx[param_info["index"]][param_info["k"]] = (output_idx_i_m, input_idx_i_m) + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_size = param_info["val"].size(0) + if "linear" in param_info["k"]: + input_idx_i_m = torch.arange(input_size, device=param_info["val"].device) + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def param_idx_to_local_params(global_parameters, client_param_idx): + """Get the local parameters from the list of param indices. + + Parameters + ---------- + global_parameters : Dict + The state_dict of global model. + client_param_idx : List + Local parameters indices with respect to global model. + + Returns + ------- + Dict + state dict of local model. 
+ """ + local_parameters = OrderedDict() + for k, val in global_parameters.items(): + parameter_type = k.split(".")[-1] + if "weight" in parameter_type or "bias" in parameter_type: + if "weight" in parameter_type: + if val.dim() > 1: + local_parameters[k] = copy.deepcopy( + val[torch.meshgrid(client_param_idx[k])] + ) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val) + return local_parameters + + +def get_state_dict_from_param(model, parameters): + """Get the state dict from model & parameters as np.NDarrays. + + Parameters + ---------- + model : nn.Module + The neural network. + parameters : np.NDarray + Parameters of the model as np.NDarrays. + + Returns + ------- + Dict + state dict of model. + """ + # Load the parameters into the model + for param_tensor, param_ndarray in zip( + model.state_dict(), parameters_to_ndarrays(parameters) + ): + model.state_dict()[param_tensor].copy_(torch.from_numpy(param_ndarray)) + # Step 3: Obtain the state_dict of the model + state_dict = model.state_dict() + return state_dict diff --git a/baselines/heterofl/heterofl/server.py b/baselines/heterofl/heterofl/server.py new file mode 100644 index 000000000000..f82db0a59fff --- /dev/null +++ b/baselines/heterofl/heterofl/server.py @@ -0,0 +1,101 @@ +"""Flower Server.""" +import time +from collections import OrderedDict +from typing import Callable, Dict, Optional, Tuple + +import torch +from flwr.common.typing import NDArrays, Scalar +from torch import nn + +from heterofl.models import test +from heterofl.utils import save_model + + +def gen_evaluate_fn( + data_loaders, + device: torch.device, + model: nn.Module, + keys, + enable_train_on_train_data: bool, +) -> Callable[ + [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]] +]: + """Generate the function for centralized evaluation. 
+ + Parameters + ---------- + data_loaders : + A dictionary containing dataloaders for testing and + label split of each client. + device : torch.device + The device to test the model on. + model : + Model for testing. + keys : + keys of the model that it is trained on. + + Returns + ------- + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] + The centralized evaluation function. + """ + intermediate_keys = keys + + def evaluate( + server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar] + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + # pylint: disable=unused-argument + """Use the entire test set for evaluation.""" + # if server_round % 5 != 0 and server_round < 395: + # return 1, {} + + net = model + params_dict = zip(intermediate_keys, parameters_ndarrays) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=False) + net.to(device) + + if server_round % 100 == 0: + save_model(net, f"model_after_round_{server_round}.pth") + + if enable_train_on_train_data is True: + print("start of testing") + start_time = time.time() + with torch.no_grad(): + net.train(True) + for images, labels in data_loaders["entire_trainloader"]: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + net(input_dict) + print(f"end of stat, time taken = {time.time() - start_time}") + + local_metrics = {} + local_metrics["loss"] = 0 + local_metrics["accuracy"] = 0 + for i, clnt_tstldr in enumerate(data_loaders["valloaders"]): + client_test_res = test( + net, + clnt_tstldr, + data_loaders["label_split"][i].type(torch.int), + device=device, + ) + local_metrics["loss"] += client_test_res[0] + local_metrics["accuracy"] += client_test_res[1] + + global_metrics = {} + global_metrics["loss"], global_metrics["accuracy"] = test( + net, data_loaders["testloader"], device=device + ) + + # return statistics + print(f"global accuracy = 
{global_metrics['accuracy']}") + print(f"local_accuracy = {local_metrics['accuracy']}") + return global_metrics["loss"], { + "global_accuracy": global_metrics["accuracy"], + "local_loss": local_metrics["loss"], + "local_accuracy": local_metrics["accuracy"], + } + + return evaluate diff --git a/baselines/heterofl/heterofl/strategy.py b/baselines/heterofl/heterofl/strategy.py new file mode 100644 index 000000000000..70dbd19594df --- /dev/null +++ b/baselines/heterofl/heterofl/strategy.py @@ -0,0 +1,467 @@ +"""Flower strategy for HeteroFL.""" +import copy +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import flwr as fl +import torch +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from torch import nn + +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.models import ( + get_parameters, + get_state_dict_from_param, + param_idx_to_local_params, + param_model_rate_mapping, +) +from heterofl.utils import make_optimizer, make_scheduler + + +# pylint: disable=too-many-instance-attributes +class HeteroFL(fl.server.strategy.Strategy): + """HeteroFL strategy. + + Distribute subsets of a global model to clients according to their + + computational complexity and aggregate received models from clients. 
+ """ + + # pylint: disable=too-many-arguments + def __init__( + self, + model_name: str, + net: nn.Module, + optim_scheduler_settings: Dict, + global_model_rate: float = 1.0, + evaluate_fn=None, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + ) -> None: + super().__init__() + self.fraction_fit = fraction_fit + self.fraction_evaluate = fraction_evaluate + self.min_fit_clients = min_fit_clients + self.min_evaluate_clients = min_evaluate_clients + self.min_available_clients = min_available_clients + self.evaluate_fn = evaluate_fn + # # created client_to_model_mapping + # self.client_to_model_rate_mapping: Dict[str, ClientProxy] = {} + + self.model_name = model_name + self.net = net + self.global_model_rate = global_model_rate + # info required for configure and aggregate + # to be filled in initialize + self.local_param_model_rate: OrderedDict = OrderedDict() + # to be filled in initialize + self.active_cl_labels: List[torch.tensor] = [] + # to be filled in configure + self.active_cl_mr: OrderedDict = OrderedDict() + # required for scheduling the lr + self.optimizer = make_optimizer( + optim_scheduler_settings["optimizer"], + self.net.parameters(), + learning_rate=optim_scheduler_settings["lr"], + momentum=optim_scheduler_settings["momentum"], + weight_decay=optim_scheduler_settings["weight_decay"], + ) + self.scheduler = make_scheduler( + optim_scheduler_settings["scheduler"], + self.optimizer, + milestones=optim_scheduler_settings["milestones"], + ) + + def __repr__(self) -> str: + """Return a string representation of the HeteroFL object.""" + return "HeteroFL" + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters.""" + # self.make_client_to_model_rate_mapping(client_manager) + # net = conv(model_rate = 1) + if not isinstance(client_manager, ClientManagerHeteroFL): + raise 
ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + + ndarrays = get_parameters(self.net) + self.local_param_model_rate = param_model_rate_mapping( + self.model_name, + self.net.state_dict(), + clnt_mngr_heterofl.get_all_clients_to_model_mapping(), + self.global_model_rate, + ) + + if clnt_mngr_heterofl.client_label_split is not None: + self.active_cl_labels = clnt_mngr_heterofl.client_label_split.copy() + + return fl.common.ndarrays_to_parameters(ndarrays) + + def configure_fit( + self, + server_round: int, + parameters: Parameters, + client_manager: ClientManager, + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + print(f"in configure fit , server round no. = {server_round}") + if not isinstance(client_manager, ClientManagerHeteroFL): + raise ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + # Sample clients + # no need to change this + clientts_selection_config = {} + ( + clientts_selection_config["sample_size"], + clientts_selection_config["min_num_clients"], + ) = self.num_fit_clients(clnt_mngr_heterofl.num_available()) + + # for sampling we pass the criterion to select the required clients + clients = clnt_mngr_heterofl.sample( + num_clients=clientts_selection_config["sample_size"], + min_num_clients=clientts_selection_config["min_num_clients"], + ) + + # update client model rate mapping + clnt_mngr_heterofl.update(server_round) + + global_parameters = get_state_dict_from_param(self.net, parameters) + + self.active_cl_mr = OrderedDict() + + # Create custom configs + fit_configurations = [] + learning_rate = self.optimizer.param_groups[0]["lr"] + print(f"lr = {learning_rate}") + for client in clients: + model_rate = clnt_mngr_heterofl.get_client_to_model_mapping(client.cid) + client_param_idx = self.local_param_model_rate[model_rate] + 
local_param = param_idx_to_local_params( + global_parameters=global_parameters, client_param_idx=client_param_idx + ) + self.active_cl_mr[client.cid] = model_rate + # local param are in the form of state_dict, + # so converting them only to values of tensors + local_param_fitres = [val.cpu() for val in local_param.values()] + fit_configurations.append( + ( + client, + FitIns( + ndarrays_to_parameters(local_param_fitres), + {"lr": learning_rate}, + ), + ) + ) + + self.scheduler.step() + return fit_configurations + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average. + + Adopted from authors implementation. + """ + print("in aggregate fit") + gl_model = self.net.state_dict() + + param_idx = [] + for res in results: + param_idx.append( + copy.deepcopy( + self.local_param_model_rate[self.active_cl_mr[res[0].cid]] + ) + ) + + local_param_as_parameters = [fit_res.parameters for _, fit_res in results] + local_parameters_as_ndarrays = [ + parameters_to_ndarrays(local_param_as_parameters[i]) + for i in range(len(local_param_as_parameters)) + ] + local_parameters: List[OrderedDict] = [ + OrderedDict() for _ in range(len(local_param_as_parameters)) + ] + for i in range(len(results)): + j = 0 + for k, _ in gl_model.items(): + local_parameters[i][k] = local_parameters_as_ndarrays[i][j] + j += 1 + + if "conv" in self.model_name: + self._aggregate_conv(param_idx, local_parameters, results) + + elif "resnet" in self.model_name: + self._aggregate_resnet18(param_idx, local_parameters, results) + else: + raise ValueError("Not valid model name") + + return ndarrays_to_parameters([v for k, v in gl_model.items()]), {} + + def _aggregate_conv(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + output_bias_name = [k for k 
in gl_model.keys() if "bias" in k][-1] + output_weight_name = [k for k in gl_model.keys() if "weight" in k][-1] + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_conv( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + { + "tmp_v": tmp_v, + "count": count, + }, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + { + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_conv( + self, + clnt_params, + tmp_v_count, + param_info, + output_names, + ): + # pi = param_info + param_idx = clnt_params["param_idx"] + clnt = param_info["clnt"] + k = param_info["k"] + tmp_v = tmp_v_count["tmp_v"] + count = tmp_v_count["count"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if k == output_names["output_weight_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + 
else: + if k == output_names["output_bias_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def _aggregate_resnet18(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_resnet18( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + tmp_v, + count, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_resnet18(self, clnt_params, tmp_v, count, param_info): + param_idx = clnt_params["param_idx"] + k = param_info["k"] + clnt = param_info["clnt"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] 
+= clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + else: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + # if self.fraction_evaluate == 0.0: + # return [] + # config = {} + # evaluate_ins = EvaluateIns(parameters, config) + + # # Sample clients + # sample_size, min_num_clients = self.num_evaluation_clients( + # client_manager.num_available() + # ) + # clients = client_manager.sample( + # num_clients=sample_size, min_num_clients=min_num_clients + # ) + + # global_parameters = get_state_dict_from_param(self.net, parameters) + + # self.active_cl_mr = OrderedDict() + + # # Create custom configs + # evaluate_configurations = [] + # for idx, client in enumerate(clients): + # model_rate = client_manager.get_client_to_model_mapping(client.cid) + # client_param_idx = self.local_param_model_rate[model_rate] + # local_param = + # param_idx_to_local_params(global_parameters, client_param_idx) + # self.active_cl_mr[client.cid] = model_rate + # # local param are in the form of state_dict, + # # so converting them only to values of tensors + # local_param_fitres = [v.cpu() for v in local_param.values()] + # evaluate_configurations.append( + # (client, EvaluateIns(ndarrays_to_parameters(local_param_fitres), {})) + # ) + # return 
evaluate_configurations + + return [] + + # return self.configure_fit(server_round , parameters , client_manager) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using weighted average.""" + # if not results: + # return None, {} + + # loss_aggregated = weighted_loss_avg( + # [ + # (evaluate_res.num_examples, evaluate_res.loss) + # for _, evaluate_res in results + # ] + # ) + + # accuracy_aggregated = 0 + # for cp, y in results: + # print(f"{cp.cid}-->{y.metrics['accuracy']}", end=" ") + # accuracy_aggregated += y.metrics["accuracy"] + # accuracy_aggregated /= len(results) + + # metrics_aggregated = {"accuracy": accuracy_aggregated} + # print(f"\npaneer lababdar {metrics_aggregated}") + # return loss_aggregated, metrics_aggregated + + return None, {} + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function.""" + if self.evaluate_fn is None: + # No evaluation function provided + return None + parameters_ndarrays = parameters_to_ndarrays(parameters) + eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {}) + if eval_res is None: + return None + loss, metrics = eval_res + return loss, metrics + + def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Return sample size and required number of clients.""" + num_clients = int(num_available_clients * self.fraction_fit) + return max(num_clients, self.min_fit_clients), self.min_available_clients + + def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Use a fraction of available clients for evaluation.""" + num_clients = int(num_available_clients * self.fraction_evaluate) + return max(num_clients, self.min_evaluate_clients), 
self.min_available_clients diff --git a/baselines/heterofl/heterofl/utils.py b/baselines/heterofl/heterofl/utils.py new file mode 100644 index 000000000000..3bcb7f3d8ea7 --- /dev/null +++ b/baselines/heterofl/heterofl/utils.py @@ -0,0 +1,218 @@ +"""Contains utility functions.""" +import errno +import os +from pathlib import Path + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig + + +def preprocess_input(cfg_model, cfg_data): + """Preprocess the input to get input shape, other derivables. + + Parameters + ---------- + cfg_model : DictConfig + Retrieve model-related information from the base.yaml configuration in Hydra. + cfg_data : DictConfig + Retrieve data-related information required to construct the model. + + Returns + ------- + Dict + Dictionary contained derived information from config. + """ + model_config = {} + # if cfg_model.model_name == "conv": + # model_config["model_name"] = + # elif for others... + model_config["model"] = cfg_model.model_name + if cfg_data.dataset_name == "MNIST": + model_config["data_shape"] = [1, 28, 28] + model_config["classes_size"] = 10 + elif cfg_data.dataset_name == "CIFAR10": + model_config["data_shape"] = [3, 32, 32] + model_config["classes_size"] = 10 + + if "hidden_layers" in cfg_model: + model_config["hidden_layers"] = cfg_model.hidden_layers + if "norm" in cfg_model: + model_config["norm"] = cfg_model.norm + if "scale" in cfg_model: + model_config["scale"] = cfg_model.scale + if "mask" in cfg_model: + model_config["mask"] = cfg_model.mask + + return model_config + + +def make_optimizer(optimizer_name, parameters, learning_rate, weight_decay, momentum): + """Make the optimizer with given config. + + Parameters + ---------- + optimizer_name : str + Name of the optimizer. + parameters : Dict + Parameters of the model. + learning_rate: float + Learning rate of the optimizer. + weight_decay: float + weight_decay of the optimizer. + + Returns + ------- + torch.optim.Optimizer + Optimizer. 
+ """ + optimizer = None + if optimizer_name == "SGD": + optimizer = torch.optim.SGD( + parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay + ) + return optimizer + + +def make_scheduler(scheduler_name, optimizer, milestones): + """Make the scheduler with given config. + + Parameters + ---------- + scheduler_name : str + Name of the scheduler. + optimizer : torch.optim.Optimizer + Parameters of the model. + milestones: List[int] + List of epoch indices. Must be increasing. + + Returns + ------- + torch.optim.lr_scheduler.Scheduler + scheduler. + """ + scheduler = None + if scheduler_name == "MultiStepLR": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=milestones + ) + return scheduler + + +def get_global_model_rate(model_mode): + """Give the global model rate from string(cfg.control.model_mode) . + + Parameters + ---------- + model_mode : str + Contains the division of computational complexties among clients. + + Returns + ------- + str + global model computational complexity. 
+ """ + model_mode = "" + model_mode + model_mode = model_mode.split("-")[0][0] + return model_mode + + +class ModelRateManager: + """Control the model rate of clients in case of simulation.""" + + def __init__(self, model_split_mode, model_split_rate, model_mode): + self.model_split_mode = model_split_mode + self.model_split_rate = model_split_rate + self.model_mode = model_mode + self.model_mode = self.model_mode.split("-") + + def create_model_rate_mapping(self, num_users): + """Change the client to model rate mapping accordingly.""" + client_model_rate = [] + + if self.model_split_mode == "fix": + mode_rate, proportion = [], [] + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + num_users_proportion = num_users // sum(proportion) + for i, comp_level in enumerate(mode_rate): + client_model_rate += np.repeat( + comp_level, num_users_proportion * proportion[i] + ).tolist() + client_model_rate = client_model_rate + [ + client_model_rate[-1] for _ in range(num_users - len(client_model_rate)) + ] + # return client_model_rate + + elif self.model_split_mode == "dynamic": + mode_rate, proportion = [], [] + + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + + proportion = (np.array(proportion) / sum(proportion)).tolist() + + rate_idx = torch.multinomial( + torch.tensor(proportion), num_samples=num_users, replacement=True + ).tolist() + client_model_rate = np.array(mode_rate)[rate_idx] + + # return client_model_rate + + else: + raise ValueError("Not valid model split mode") + + return client_model_rate + + +def save_model(model, path): + """To save the model in the given path.""" + # print('in save model') + current_path = HydraConfig.get().runtime.output_dir + model_save_path = Path(current_path) / path + torch.save(model.state_dict(), model_save_path) + + +# """ The 
following functions(check_exists, makedir_exit_ok, save, load) +# are adopted from authors (of heterofl) implementation.""" + + +def check_exists(path): + """Check if the given path exists.""" + return os.path.exists(path) + + +def makedir_exist_ok(path): + """Create a directory.""" + try: + os.makedirs(path) + except OSError as os_err: + if os_err.errno == errno.EEXIST: + pass + else: + raise + + +def save(inp, path, protocol=2, mode="torch"): + """Save the inp in a given path.""" + dirname = os.path.dirname(path) + makedir_exist_ok(dirname) + if mode == "torch": + torch.save(inp, path, pickle_protocol=protocol) + elif mode == "numpy": + np.save(path, inp, allow_pickle=True) + else: + raise ValueError("Not valid save mode") + + +# pylint: disable=no-else-return +def load(path, mode="torch"): + """Load the file from given path.""" + if mode == "torch": + return torch.load(path, map_location=lambda storage, loc: storage) + elif mode == "numpy": + return np.load(path, allow_pickle=True) + else: + raise ValueError("Not valid save mode") diff --git a/baselines/heterofl/pyproject.toml b/baselines/heterofl/pyproject.toml new file mode 100644 index 000000000000..0f72edf20345 --- /dev/null +++ b/baselines/heterofl/pyproject.toml @@ -0,0 +1,145 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "heterofl" # <----- Ensure it matches the name of your baseline directory containing all the source code +version = "1.0.0" +description = "HeteroFL : Computation And Communication Efficient Federated Learning For Heterogeneous Clients" +license = "Apache-2.0" +authors = ["M S Chaitanya Kumar ", "The Flower Authors "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: 
Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10.0, <3.11.0" +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +anytree = "^2.12.1" +types-six = "^1.16.21.9" +tqdm = "4.66.1" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" +virtualenv = "20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] 
+disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/baselines/hfedxgboost/.gitignore b/baselines/hfedxgboost/.gitignore new file mode 100644 index 000000000000..3d66a02ea3cc --- /dev/null +++ b/baselines/hfedxgboost/.gitignore @@ -0,0 +1,2 @@ +dataset/ +outputs/ diff --git a/baselines/hfedxgboost/LICENSE b/baselines/hfedxgboost/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/hfedxgboost/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/hfedxgboost/README.md b/baselines/hfedxgboost/README.md new file mode 100644 index 000000000000..2f31e2c4c584 --- /dev/null +++ b/baselines/hfedxgboost/README.md @@ -0,0 +1,297 @@ +--- +title: Gradient-less Federated Gradient Boosting Trees with Learnable Learning Rates +URL: https://arxiv.org/abs/2304.07537 +labels: [cross-silo, tree-based, XGBoost, Classification, Regression, Tabular] +dataset: [a9a, cod-rna, ijcnn1, space_ga, cpusmall, YearPredictionMSD] +--- + +# Gradient-less Federated Gradient Boosting Trees with Learnable Learning Rates + +> Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper. 
+ +**Paper:** [arxiv.org/abs/2304.07537](https://arxiv.org/abs/2304.07537) + +**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Lane + +**Abstract:** The privacy-sensitive nature of decentralized datasets and the robustness of eXtreme Gradient Boosting (XGBoost) on tabular data raise the need to train XGBoost in the context of federated learning (FL). Existing works on federated XGBoost in the horizontal setting rely on the sharing of gradients, which induce per-node level communication frequency and serious privacy concerns. To alleviate these problems, we develop an innovative framework for horizontal federated XGBoost which does not depend on the sharing of gradients and simultaneously boosts privacy and communication efficiency by making the learning rates of the aggregated tree ensembles learnable. We conduct extensive evaluations on various classification and regression datasets, showing our approach achieves performance comparable to the state-of-the-art method and effectively improves communication efficiency by lowering both communication rounds and communication overhead by factors ranging from 25x to 700x. + + +## About this baseline + +**What’s implemented:** The code in this directory replicates the experiments in "Gradient-less Federated Gradient Boosting Trees with Learnable Learning Rates" (Ma et al., 2023), which proposed the FedXGBllr algorithm, for the a9a, cod-rna, ijcnn1 and space_ga datasets. Concretely, it replicates the results for the a9a, cod-rna, ijcnn1 and space_ga datasets in Table 2. + +**Datasets:** a9a, cod-rna, ijcnn1, space_ga + +**Hardware Setup:** Most of the experiments were done on a machine with an Intel® Core™ i7-6820HQ processor, which has 4 cores and 8 threads. 
+ +**Contributors:** [Aml Hassan Esmil](https://github.com/Aml-Hassan-Abd-El-hamid) + +## Experimental Setup + +**Task:** Tabular classification and regression + +**Model:** XGBoost model combined with 1-layer CNN + +**Dataset:** +This baseline only includes 7 datasets with a focus on 4 of them (a9a, cod-rna, ijcnn1, space_ga). + +Each dataset can be partitioned across 2, 5 or 10 clients in an IID distribution. + +| task type | Dataset | no.of features | no.of samples | +| :---: | :---: | :---: | :---: | +| Binary classification | a9a
cod-rna
ijcnn1 | 123
8
22 | 32,561
59,535
49,990 | +| Regression | abalone
cpusmall
space_ga
YearPredictionMSD | 8
12
6
90 | 4,177
8,192
3,167
515,345 | + + +**Training Hyperparameters:** +For the centralized model, the paper's hyperparameters were mostly used, as they give very good results (except for abalone and cpusmall). Here are the hyperparameters used; they can all be found in the `yaml` file named `paper_xgboost_centralized`: + +| Hyperparameter name | value | +| -- | -- | +| n_estimators | 500 | +| max_depth | 8 | +| subsample | 0.8 | +| learning_rate | .1 | +| colsample_bylevel | 1 | +| colsample_bynode | 1 | +| colsample_bytree | 1 | +| alpha | 5 | +| gamma | 5 | +| num_parallel_tree | 1 | +| min_child_weight | 1 | + +Here are all the original hyperparameters for the federated horizontal XGBoost model. Hyperparameters that are used only by the XGBoost model are prefixed with `xgb`, and those used only by Adam are prefixed with `Adam`: + +| Hyperparameter name | value | +| -- | -- | +| n_estimators | 500/no.of clients | +| xgb max_depth | 8 | +| xgb subsample | 0.8 | +| xgb learning_rate | .1 | +| xgb colsample_bylevel | 1 | +| xgb colsample_bynode | 1 | +| xgb colsample_bytree | 1 | +| xgb alpha | 5 | +| xgb gamma | 5 | +| xgb num_parallel_tree | 1 | +| xgb min_child_weight | 1 | +| Adam learning rate | .0001 | +| Adam Betas | 0.5, 0.999 | +| no.of iterations for the CNN model | 100 | + +Those hyperparameters did well for most datasets, but for some datasets they didn't give the best performance, so a fine-tuning journey was started in order to achieve better results.
+At first, it was a manual process, basically experimenting with different values for some groups of hyperparameters to explore those hyperparameters' effect on the performance of different datasets, until I decided to focus on the following group of hyperparameters, as they seemed to have the major effect on the performance of different datasets: +| Hyperparameter name | +| -- | +| n_estimators | +| xgb max_depth | +| Adam learning rate | +| no.of iterations for the CNN model | + +All the final new values for those hyperparameters can be found in 3 `yaml` files per dataset named `<dataset_name>_<no.of clients>_clients` and all the original values for those hyperparameters can be found in 3 `yaml` files named `paper_<no.of clients>_clients`. This resulted in a large number of config files: 3*7+3 = 24 config files in the `clients` folder. + +## Environment Setup + +These steps assume you have already installed `Poetry` and `pyenv`. In this directory (i.e. `/baselines/hfedxgboost`), where you can see `pyproject.toml`, execute the following commands in your terminal: + +```bash +# Set python version +pyenv local 3.10.6 +# Tell Poetry to use it +poetry env use 3.10.6 +# Install all dependencies +poetry install +# Activate your environment +poetry shell +``` + +## Running the Experiments + +With your environment activated you can run the experiments directly. The datasets will be downloaded automatically. 
+ +```bash +# to run the experiments for the centralized model with customized hyperparameters run +python -m hfedxgboost.main --config-name Centralized_Baseline dataset=<dataset_name> xgboost_params_centralized=<dataset_name>_xgboost_centralized +#e.g +# to run the centralized model with customized hyperparameters for cpusmall dataset +python -m hfedxgboost.main --config-name Centralized_Baseline dataset=cpusmall xgboost_params_centralized=cpusmall_xgboost_centralized + +# to run the federated version for any dataset with a given no.of clients +python -m hfedxgboost.main dataset=<dataset_name> clients=<dataset_name>_<no.of clients>_clients +# for example +# to run the federated version for a9a dataset with 5 clients +python -m hfedxgboost.main dataset=a9a clients=a9a_5_clients + +# if you wish to change any parameters from any config file from the terminal, then you should follow this formula +python -m hfedxgboost.main folder=config_file_name folder.parameter_name=its_new_value +#e.g: +python -m hfedxgboost.main --config-name Centralized_Baseline dataset=abalone xgboost_params_centralized=abalone_xgboost_centralized xgboost_params_centralized.max_depth=8 dataset.train_ratio=.80 +``` + + +## Expected Results + +This section shows how to reproduce some of the results in the paper. Tables 2 and 3 were obtained using different hyperparameters than those indicated in the paper. Without these, some experiments exhibited worse performance. Still, some results remain far from those in the original paper. + +### Table 1: Centralized Evaluation +```bash +# to run all the experiments for the centralized model with the original paper config for all the datasets +# gives the output shown in Table 1 +python -m hfedxgboost.main --config-name centralized_basline_all_datasets_paper_config + +# Please note that unlike in the federated experiments, the results will be only printed on the terminal +# and won't be logged into a file. 
+``` +| Dataset | task type | test result | +| :---: | :---: | :---: | +| a9a | Binary classification | 84.9% | +| cod-rna | Binary classification | 97.3% | +| ijcnn1 | Binary classification | 98.7% | +| abalone | Regression | 4.6 | +| cpusmall | Regression | 9 | +| space_ga | Regression | .032 | +| YearPredictionMSD | Regression | 76.41 | + +### Table 2: Federated Binary Classification + +```bash +# Results for a9a dataset in table 2 +python -m hfedxgboost.main --multirun clients=a9a_2_clients,a9a_5_clients,a9a_10_clients dataset=a9a + +# Results for cod_rna dataset in table 2 +python -m hfedxgboost.main --multirun clients=cod_rna_2_clients,cod_rna_5_clients,cod_rna_10_clients dataset=cod_rna + +# Results for ijcnn1 dataset in table 2 +python -m hfedxgboost.main --multirun clients=ijcnn1_2_clients,ijcnn1_5_clients,ijcnn1_10_clients dataset=ijcnn1 +``` + +| Dataset | task type |no. of clients | server-side test Accuracy | +| :---: | :---: | :---: | :---: | +| a9a | Binary Classification | 2
5
10 | 84.4%
84.2%
83.7% | +| cod_rna | Binary Classification | 2
5
10 | 96.4%
96.2%
95.0% | +| ijcnn1 | Binary Classification |2
5
10 | 98.0%
97.28%
96.8% | + + +### Table 3: Federated Regression +```bash +# Notice that: the MSE results shown in the tables usually happen in early FL rounds (instead in the last round/s) +# Results for space_ga dataset in table 3 +python -m hfedxgboost.main --multirun clients=space_ga_2_clients,space_ga_5_clients,space_ga_10_clients dataset=space_ga + +# Results for abalone dataset in table 3 +python -m hfedxgboost.main --multirun clients=abalone_2_clients,abalone_5_clients,abalone_10_clients dataset=abalone + +# Results for cpusmall dataset in table 3 +python -m hfedxgboost.main --multirun clients=cpusmall_2_clients,cpusmall_5_clients,cpusmall_10_clients dataset=cpusmall + +# Results for YearPredictionMSD_2 dataset in table 3 +python -m hfedxgboost.main --multirun clients=YearPredictionMSD_2_clients,YearPredictionMSD_5_clients,YearPredictionMSD_10_clients dataset=YearPredictionMSD +``` + +| Dataset | task type |no. of clients | server-side test MSE | +| :---: | :---: | :---: | :---: | +| space_ga | Regression | 2
5
10 | 0.024
0.033
0.034 | +| abalone | Regression | 2
5
10 | 5.5
6.87
7.5 | +| cpusmall | Regression | 2
5
10 | 13
15.13
15.28 | +| YearPredictionMSD | Regression | 2
5
10 | 119
118
118 | + + +## Doing your own finetuning + +There are 3 main things that you should consider: + +1- You can use WandB to automate the fine-tuning process: modify the `sweep.yaml` file to control your experiment settings, including your search methods, values to choose from, etc. Below we demonstrate how to run the `wandb` sweep. +If you're new to `wandb` you might want to read the following resources to [do hyperparameter tuning with W&B+PyTorch](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W%26B.ipynb), and [use W&B alongside Hydra](https://wandb.ai/adrishd/hydra-example/reports/Configuring-W-B-Projects-with-Hydra--VmlldzoxNTA2MzQw). + +``` +# Remember to activate the poetry shell +poetry shell + +# login to your wandb account +wandb login + +# Inside the folder flower/baselines/hfedxgboost/hfedxgboost run the commands below + +# Initiate WandB sweep +wandb sweep sweep.yaml + +# that command -if ran with no error- will return a line that contains +# the command that you can use to run the sweep agent, it'll look something like that: + +wandb agent <your wandb username>/flower-baselines_hfedxgboost_hfedxgboost/<sweep id> + +``` + +2- The config files named `<dataset_name>_<no.of clients>_clients.yaml` are meant to keep the final hyperparameter values, so whenever you think you're done with fine-tuning some hyperparameters, add them to their config files so the one after you can use them. + +3- To help with the fine-tuning of the hyperparameters process, there are 2 classes in the utils.py that write down the used hyperparameters in the experiments and the results for that experiment in 2 separate CSV files; some of the hyperparameters used in the experiments done during building this baseline can be found in results.csv and results_centralized.csv files.
+More importantly, those 2 classes focus on writing down only the hyperparameters that I thought were important, so if you're interested in experimenting with other hyperparameters, don't forget to add them to the writer classes so you can track them more easily, especially if you intend to do some experiments away from WandB. + + +## How to add a new dataset + +This code doesn't cover all the datasets from the paper yet, so if you wish to add a new dataset, here are the steps: + +**1- you need to download the dataset from its source:** +- In the `dataset_preparation.py` file, specifically in the `download_data` function add the code to download your dataset -or if you already downloaded it manually add the code to return its file path- it could look something like the following example: +``` +if dataset_name=="<your dataset name>": + DATASET_PATH=os.path.join(ALL_DATASETS_PATH, "<your dataset name>") + if not os.path.exists(DATASET_PATH): + os.makedirs(DATASET_PATH) + urllib.request.urlretrieve( + "<URL of the first file of your dataset>", + f"{os.path.join(DATASET_PATH, '<name of the first file>')}", + ) + urllib.request.urlretrieve( + "<URL of the second file of your dataset>", + f"{os.path.join(DATASET_PATH, '<name of the second file>')}", + ) + # if the 2 files of your dataset are divided into training and test file put the training then test ✅ + return [os.path.join(DATASET_PATH, '<name of the first file>'),os.path.join(DATASET_PATH, '<name of the second file>')] +``` +That function will be called in the `dataset.py` file in the `load_single_dataset` function and the different files of your dataset will be concatenated -if your dataset is one file then nothing will happen, it will just be loaded- using the `datafiles_fusion` function from the `dataset_preparation.py` file. 
+ +:warning: If any of your dataset's files end with `.bz2`, you have to add the following piece of code before the return line and inside the `if` condition: +``` +for filepath in os.listdir(DATASET_PATH): + abs_filepath = os.path.join(DATASET_PATH, filepath) + with bz2.BZ2File(abs_filepath) as fr, open(abs_filepath[:-4], "wb") as fw: + shutil.copyfileobj(fr, fw) +``` + +:warning: The `datafiles_fusion` function uses `sklearn.datasets.load_svmlight_file` to load the dataset. If your dataset is `csv` or some other format, that function won't work on it and you will have to alter the `datafiles_fusion` function to work with your dataset's file format. + +**2- Add config files for your dataset:** + +**a- config files for the centralized baseline:** + +- To run the centralized model on your dataset with the original hyper-parameters from the paper alongside all the other datasets added before, just do the following step: + - in the dictionary called `dataset_tasks` in the `utils.py` file, add your dataset name as a key -the same name that you put in the `download_data` function in the step before- and add its task type. This code supports 2 tasks: `BINARY`, which is binary classification, or `REG`, which is regression.
+ +- To run the centralized model on your dataset you need to create a config file `.yaml` in the `xgboost_params_centralized` folder and another .yaml file in the `dataset` folder -you will find that one of course inside the `conf` folder :) - and you need to specify the hyper-parameters of your choice for the xgboost model + + - the .yaml file in the `dataset` folder should look something like this: + ``` + defaults: + - task: + dataset_name: "" + train_ratio: + early_stop_patience_rounds: + ``` + - the .yaml file in the `xgboost_params_centralized` folder should contain the values for all the hyper-parameters of your choice for the xgboost model + +You can skip this whole step and use the default hyper-parameters from the paper; they're all written in the "paper__clients.yaml" files.
+**b- config files for the federated baseline:** + +To run the federated baseline with your dataset using your customized hyper-parameters, you first need to create the .yaml file in the `dataset` folder that was mentioned before, and then create a config file that contains the number of clients; it should look something like this: +``` +n_estimators_client: +num_rounds: +client_num: +num_iterations: +xgb: + max_depth: +CNN: + lr: +``` diff --git a/baselines/hfedxgboost/hfedxgboost/__init__.py b/baselines/hfedxgboost/hfedxgboost/__init__.py new file mode 100644 index 000000000000..543147a05591 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/__init__.py @@ -0,0 +1 @@ +"""hfedxgboost baseline package.""" diff --git a/baselines/hfedxgboost/hfedxgboost/client.py b/baselines/hfedxgboost/hfedxgboost/client.py new file mode 100644 index 000000000000..22435e20415b --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/client.py @@ -0,0 +1,294 @@ +"""Define your client class and a function to construct such clients. + +Please overwrite `flwr.client.NumPyClient` or `flwr.client.Client` and create a function +to instantiate your client.
+""" +from typing import Any, Tuple + +import flwr as fl +import torch +from flwr.common import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersRes, + Status, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from hfedxgboost.models import CNN, fit_xgboost +from hfedxgboost.utils import single_tree_preds_from_each_client + + +class FlClient(fl.client.Client): + """Custom class contains the methods that the client need.""" + + def __init__( + self, + cfg: DictConfig, + trainloader: DataLoader, + valloader: DataLoader, + cid: str, + ): + self.cid = cid + self.config = cfg + + self.trainloader_original = trainloader + self.valloader_original = valloader + self.valloader: Any + + # instantiate model + self.net = CNN(cfg) + + # determine device + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def train_one_loop(self, data, optimizer, metric_fn, criterion): + """Trains the neural network model for one loop iteration. + + Parameters + ---------- + data (tuple): A tuple containing the inputs and + labels for the training data, where the input represent the predictions + of the trees from the tree ensemples of the clients. + + Returns + ------- + loss (float): The value of the loss function after the iteration. + metric_val * n_samples (float): The value of the chosen evaluation metric + (accuracy or MSE) after the iteration. + n_samples (int): The number of samples used for training in the iteration. 
+ """ + tree_outputs, labels = data[0].to(self.device), data[1].to(self.device) + optimizer.zero_grad() + + outputs = self.net(tree_outputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # Collected training loss and accuracy statistics + n_samples = labels.size(0) + metric_val = metric_fn(outputs, labels.type(torch.int)) + + return loss.item(), metric_val * n_samples, n_samples + + def train( + self, + net: CNN, + trainloader: DataLoader, + num_iterations: int, + ) -> Tuple[float, float, int]: + """Train CNN model on a given dataset(trainloader) for(num_iterations). + + Parameters + ---------- + net (CNN): The convolutional neural network to be trained. + trainloader (DataLoader): The data loader object containing + the training dataset. + num_iterations (int): The number of iterations or batches to + be processed by the network. + + Returns + ------- + Tuple[float, float, int]: A tuple containing the average loss per sample, + the average evaluation result per sample, and the total number of training + samples processed. + + Note: + + - The training is formulated in terms of the number of updates or iterations + processed by the network. + """ + net.train() + total_loss, total_result, total_n_samples = 0.0, 0.0, 0 + + # Unusually, this training is formulated in terms of number of + # updates/iterations/batches processed + # by the network. This will be helpful later on, when partitioning the + # data across clients: resulting + # in differences between dataset sizes and hence inconsistent numbers of updates + # per 'epoch'. 
+ optimizer = torch.optim.Adam( + self.net.parameters(), lr=self.config.clients.CNN.lr, betas=(0.5, 0.999) + ) + metric_fn = instantiate(self.config.dataset.task.metric.fn) + criterion = instantiate(self.config.dataset.task.criterion) + for _i, data in zip(range(num_iterations), trainloader): + loss, metric_val, n_samples = self.train_one_loop( + data, optimizer, metric_fn, criterion + ) + total_loss += loss + total_result += metric_val + total_n_samples += n_samples + + return ( + total_loss / total_n_samples, + total_result / total_n_samples, + total_n_samples, + ) + + def test(self, net: CNN, testloader: DataLoader) -> Tuple[float, float, int]: + """Evaluates the network on test data. + + Parameters + ---------- + net: The CNN model to be tested. + testloader: The data loader containing the test data. + + Return: A tuple containing the average loss, + average metric result, + and the total number of samples tested. + """ + total_loss, total_result, n_samples = 0.0, 0.0, 0 + net.eval() + metric_fn = instantiate(self.config.dataset.task.metric.fn) + criterion = instantiate(self.config.dataset.task.criterion) + with torch.no_grad(): + for data in testloader: + tree_outputs, labels = data[0].to(self.device), data[1].to(self.device) + outputs = net(tree_outputs) + total_loss += criterion(outputs, labels).item() + n_samples += labels.size(0) + metric_val = metric_fn(outputs.cpu(), labels.type(torch.int).cpu()) + total_result += metric_val * labels.size(0) + + return total_loss / n_samples, total_result / n_samples, n_samples + + def get_parameters(self, ins): + """Get CNN net weights and the tree. + + Parameters + ---------- + - self (object): The instance of the class that the function belongs to. + - ins (GetParametersIns): An input parameter object. + + Returns + ------- + Tuple[GetParametersRes, + Union[Tuple[XGBClassifier, int],Tuple[XGBRegressor, int]]]: + A tuple containing the parameters of the net and the tree. 
+ - GetParametersRes: + - status : An object with the status code. + - parameters : An ndarray containing the model's weights. + - Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]: + A tuple containing either an XGBClassifier or XGBRegressor + object along with client's id. + """ + for dataset in self.trainloader_original: + data, label = dataset[0], dataset[1] + + tree = fit_xgboost( + self.config, self.config.dataset.task.task_type, data, label, 100 + ) + return GetParametersRes( + status=Status(Code.OK, ""), + parameters=ndarrays_to_parameters(self.net.get_weights()), + ), (tree, int(self.cid)) + + def fit(self, ins: FitIns) -> FitRes: + """Trains a model using the given fit parameters. + + Parameters + ---------- + ins: FitIns - The fit parameters that contain the configuration + and parameters needed for training. + + Returns + ------- + FitRes - An object that contains the status, trained parameters, + number of examples processed, and metrics. + """ + num_iterations = ins.config["num_iterations"] + batch_size = ins.config["batch_size"] + + # set parmeters + self.net.set_weights(parameters_to_ndarrays(ins.parameters[0])) # type: ignore # noqa: E501 # pylint: disable=line-too-long + aggregated_trees = ins.parameters[1] # type: ignore # noqa: E501 # pylint: disable=line-too-long + + if isinstance(aggregated_trees, list): + print("Client " + self.cid + ": recieved", len(aggregated_trees), "trees") + else: + print("Client " + self.cid + ": only had its own tree") + trainloader: Any = single_tree_preds_from_each_client( + self.trainloader_original, + batch_size, + aggregated_trees, + self.config.n_estimators_client, + self.config.clients.client_num, + ) + self.valloader = single_tree_preds_from_each_client( + self.valloader_original, + batch_size, + aggregated_trees, + self.config.n_estimators_client, + self.config.clients.client_num, + ) + + # runs for a single epoch, however many updates it may be + num_iterations = int(num_iterations) or 
len(trainloader) + # Train the model + print( + "Client", self.cid, ": training for", num_iterations, "iterations/updates" + ) + self.net.to(self.device) + train_loss, train_result, num_examples = self.train( + self.net, + trainloader, + num_iterations=num_iterations, + ) + print( + f"Client {self.cid}: training round complete, {num_examples}", + "examples processed", + ) + + # Return training information: model, number of examples processed and metrics + return FitRes( + status=Status(Code.OK, ""), + parameters=self.get_parameters(ins.config), + num_examples=num_examples, + metrics={ + "loss": train_loss, + self.config.dataset.task.metric.name: train_result, + }, + ) + + def evaluate(self, ins: EvaluateIns) -> EvaluateRes: + """Evaluate CNN model using the given evaluation parameters. + + Parameters + ---------- + ins: An instance of EvaluateIns class that contains the parameters + for evaluation. + Return: + An EvaluateRes object that contains the evaluation results. + """ + # set the weights of the CNN net + self.net.set_weights(parameters_to_ndarrays(ins.parameters)) + + # Evaluate the model + self.net.to(self.device) + loss, result, num_examples = self.test( + self.net, + self.valloader, + ) + + # Return evaluation information + print( + f"Client {self.cid}: evaluation on {num_examples} examples:", + f"loss={loss:.4f}", + self.config.dataset.task.metric.name, + f"={result:.4f}", + ) + return EvaluateRes( + status=Status(Code.OK, ""), + loss=loss, + num_examples=num_examples, + metrics={self.config.dataset.task.metric.name: result}, + ) diff --git a/baselines/hfedxgboost/hfedxgboost/conf/Centralized_Baseline.yaml b/baselines/hfedxgboost/hfedxgboost/conf/Centralized_Baseline.yaml new file mode 100644 index 000000000000..aac3f68bbe51 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/Centralized_Baseline.yaml @@ -0,0 +1,22 @@ +centralized: True +defaults: + - dataset: abalone + - xgboost_params_centralized: abalone_xgboost_centralized + 
+n_estimators_client: ${xgboost_params_centralized.n_estimators} +task_type: ${dataset.task.task_type} + +XGBoost: + _target_: ${dataset.task.xgb._target_} + objective: ${dataset.task.xgb.objective} + learning_rate: ${xgboost_params_centralized.learning_rate} + max_depth: ${xgboost_params_centralized.max_depth} + n_estimators: ${xgboost_params_centralized.n_estimators} + subsample: ${xgboost_params_centralized.subsample} + colsample_bylevel: ${xgboost_params_centralized.colsample_bylevel} + colsample_bynode: ${xgboost_params_centralized.colsample_bynode} + colsample_bytree: ${xgboost_params_centralized.colsample_bytree} + alpha: ${xgboost_params_centralized.alpha} + gamma: ${xgboost_params_centralized.gamma} + num_parallel_tree: ${xgboost_params_centralized.num_parallel_tree} + min_child_weight: ${xgboost_params_centralized.min_child_weight} diff --git a/baselines/hfedxgboost/hfedxgboost/conf/base.yaml b/baselines/hfedxgboost/hfedxgboost/conf/base.yaml new file mode 100644 index 000000000000..284aab88e316 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/base.yaml @@ -0,0 +1,55 @@ +--- +defaults: + - dataset: cpusmall + - clients: cpusmall_5_clients + - wandb: default + +centralized: False +use_wandb: False +show_each_client_performance_on_its_local_data: False +val_ratio: 0.0 +batch_size: "whole" +n_estimators_client: ${clients.n_estimators_client} +task_type: ${dataset.task.task_type} +client_num: ${clients.client_num} + +XGBoost: + _target_: ${dataset.task.xgb._target_} + objective: ${dataset.task.xgb.objective} + learning_rate: .1 + max_depth: ${clients.xgb.max_depth} + n_estimators: ${clients.n_estimators_client} + subsample: 0.8 + colsample_bylevel: 1 + colsample_bynode: 1 + colsample_bytree: 1 + alpha: 5 + gamma: 5 + num_parallel_tree: 1 + min_child_weight: 1 + +server: + max_workers: None + device: "cpu" + +client_resources: + num_cpus: 1 + num_gpus: 0.0 + +strategy: + _target_: hfedxgboost.strategy.FedXgbNnAvg + _recursive_: true #everything to be 
instantiated + fraction_fit: 1.0 + fraction_evaluate: 0.0 # no clients will be sampled for federated evaluation (we will still perform global evaluation) + min_fit_clients: 1 + min_evaluate_clients: 1 + min_available_clients: ${client_num} + accept_failures: False + +run_experiment: + num_rounds: ${clients.num_rounds} + batch_size: 32 + fraction_fit: 1.0 + min_fit_clients: 1 + fit_config: + num_iterations: ${clients.num_iterations} diff --git a/baselines/hfedxgboost/hfedxgboost/conf/centralized_basline_all_datasets_paper_config.yaml b/baselines/hfedxgboost/hfedxgboost/conf/centralized_basline_all_datasets_paper_config.yaml new file mode 100644 index 000000000000..51d168021ac9 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/centralized_basline_all_datasets_paper_config.yaml @@ -0,0 +1,46 @@ +centralized: True +learning_rate: 0.1 +max_depth: 8 +n_estimators: 500 +subsample: 0.8 +colsample_bylevel: 1 +colsample_bynode: 1 +colsample_bytree: 1 +alpha: 5 +gamma: 5 +num_parallel_tree: 1 +min_child_weight: 1 + +dataset: + dataset_name: "all" + train_ratio: .75 + +XGBoost: + classifier: + _target_: xgboost.XGBClassifier + objective: "binary:logistic" + learning_rate: ${learning_rate} + max_depth: ${max_depth} + n_estimators: ${n_estimators} + subsample: ${subsample} + colsample_bylevel: ${colsample_bylevel} + colsample_bynode: ${colsample_bynode} + colsample_bytree: ${colsample_bytree} + alpha: ${alpha} + gamma: ${gamma} + num_parallel_tree: ${num_parallel_tree} + min_child_weight: ${min_child_weight} + regressor: + _target_: xgboost.XGBRegressor + objective: "reg:squarederror" + learning_rate: ${learning_rate} + max_depth: ${max_depth} + n_estimators: ${n_estimators} + subsample: ${subsample} + colsample_bylevel: ${colsample_bylevel} + colsample_bynode: ${colsample_bynode} + colsample_bytree: ${colsample_bytree} + alpha: ${alpha} + gamma: ${gamma} + num_parallel_tree: ${num_parallel_tree} + min_child_weight: ${min_child_weight} diff --git 
a/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_10_clients.yaml new file mode 100644 index 000000000000..48fef3d2dd57 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 50 +num_rounds: 200 +client_num: 10 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_2_clients.yaml new file mode 100644 index 000000000000..d960e5ee5f40 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 250 +num_rounds: 200 +client_num: 2 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: 0.0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_5_clients.yaml new file mode 100644 index 000000000000..7e807e873b17 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/YearPredictionMSD_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 200 +client_num: 5 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_10_clients.yaml new file mode 100644 index 000000000000..4839ccb6dc91 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 10 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .001 \ No newline at end of file diff --git 
a/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_2_clients.yaml new file mode 100644 index 000000000000..f38cf782a239 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 250 +num_rounds: 15 +client_num: 2 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_5_clients.yaml new file mode 100644 index 000000000000..c331db5e258a --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/a9a_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 5 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0005 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_10_clients.yaml new file mode 100644 index 000000000000..055db27bef85 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 25 +num_rounds: 200 +client_num: 10 +num_iterations: 100 + +xgb: + max_depth: 6 +CNN: + lr: .0006301009302952918 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_2_clients.yaml new file mode 100644 index 000000000000..e4fecac5fb4e --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 50 +num_rounds: 100 +client_num: 2 +num_iterations: 100 + +xgb: + max_depth: 6 +CNN: + lr: 0.0028231080803766 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_5_clients.yaml 
new file mode 100644 index 000000000000..0610209d0b70 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/abalone_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 50 +num_rounds: 200 +client_num: 5 +num_iterations: 500 + +xgb: + max_depth: 6 +CNN: + lr: .0004549072000953885 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_10_clients.yaml new file mode 100644 index 000000000000..4839ccb6dc91 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 10 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_2_clients.yaml new file mode 100644 index 000000000000..9270ae839675 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 250 +num_rounds: 30 +client_num: 2 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_5_clients.yaml new file mode 100644 index 000000000000..9237b5c4362a --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/cod_rna_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 5 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_10_clients.yaml new file mode 100644 index 000000000000..e6882134ec84 --- /dev/null +++ 
b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 50 +num_rounds: 1000 +num_iterations: 500 +client_num: 10 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_2_clients.yaml new file mode 100644 index 000000000000..bd8552875412 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 1000 +client_num: 2 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_5_clients.yaml new file mode 100644 index 000000000000..31a740714b2a --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/cpusmall_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 25 +num_rounds: 500 +client_num: 5 +num_iterations: 100 + +xgb: + max_depth: 6 +CNN: + lr: .000457414512764587 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_10_clients.yaml new file mode 100644 index 000000000000..cf96d4e5c394 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 10 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0005 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_2_clients.yaml new file mode 100644 index 000000000000..69bbf5b26701 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 250 +num_rounds: 
50 +client_num: 2 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_5_clients.yaml new file mode 100644 index 000000000000..945dbe885345 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/ijcnn1_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 30 +client_num: 5 +num_iterations: 500 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_10_clients.yaml new file mode 100644 index 000000000000..08076f993a4c --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 50 +num_rounds: 100 +client_num: 10 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_2_clients.yaml new file mode 100644 index 000000000000..96802df2c193 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 250 +num_rounds: 100 +client_num: 2 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_5_clients.yaml new file mode 100644 index 000000000000..cec270b5a52d --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/paper_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 100 +client_num: 5 +num_iterations: 100 + +xgb: + max_depth: 8 +CNN: + lr: .0001 \ No newline at end of file diff --git 
a/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_10_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_10_clients.yaml new file mode 100644 index 000000000000..ede979063464 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_10_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 50 +client_num: 10 +num_iterations: 500 + +xgb: + max_depth: 4 +CNN: + lr: .00001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_2_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_2_clients.yaml new file mode 100644 index 000000000000..a9a2b8bb38a9 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_2_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 50 +client_num: 2 +num_iterations: 500 + +xgb: + max_depth: 4 +CNN: + lr: .00001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_5_clients.yaml b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_5_clients.yaml new file mode 100644 index 000000000000..f10de46665c1 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/clients/space_ga_5_clients.yaml @@ -0,0 +1,9 @@ +n_estimators_client: 100 +num_rounds: 200 +client_num: 5 +num_iterations: 500 + +xgb: + max_depth: 4 +CNN: + lr: .000001 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/YearPredictionMSD.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/YearPredictionMSD.yaml new file mode 100644 index 000000000000..b7d7ed14bfd6 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/YearPredictionMSD.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Regression +dataset_name: "YearPredictionMSD" +train_ratio: .75 +early_stop_patience_rounds: 50 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/a9a.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/a9a.yaml new file mode 
100644 index 000000000000..ded8e4fe40c7 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/a9a.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Binary_Classification +dataset_name: "a9a" +train_ratio: .75 +early_stop_patience_rounds: 10 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/abalone.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/abalone.yaml new file mode 100644 index 000000000000..6aad6c65a068 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/abalone.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Regression +dataset_name: "abalone" +train_ratio: .75 +early_stop_patience_rounds: 30 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/cod_rna.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/cod_rna.yaml new file mode 100644 index 000000000000..ea6e68554f80 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/cod_rna.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Binary_Classification +dataset_name: "cod-rna" +train_ratio: .75 +early_stop_patience_rounds: 10 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/cpusmall.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/cpusmall.yaml new file mode 100644 index 000000000000..6aeec735b4b0 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/cpusmall.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Regression +dataset_name: "cpusmall" +train_ratio: .75 +early_stop_patience_rounds: 100 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/ijcnn1.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/ijcnn1.yaml new file mode 100644 index 000000000000..0678f04d0d69 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/ijcnn1.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Binary_Classification +dataset_name: "ijcnn1" +train_ratio: .75 +early_stop_patience_rounds: 10 \ No newline at end of file diff --git 
a/baselines/hfedxgboost/hfedxgboost/conf/dataset/space_ga.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/space_ga.yaml new file mode 100644 index 000000000000..fa89f8e852bd --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/space_ga.yaml @@ -0,0 +1,5 @@ +defaults: + - task: Regression +dataset_name: "space_ga" +train_ratio: .75 +early_stop_patience_rounds: 50 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Binary_Classification.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Binary_Classification.yaml new file mode 100644 index 000000000000..f4e282e181bf --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Binary_Classification.yaml @@ -0,0 +1,14 @@ +task_type: "BINARY" + +metric: + name: "Accuracy" + fn: + _target_: torchmetrics.Accuracy + task: "binary" + +criterion: + _target_: torch.nn.BCELoss + +xgb: + _target_: xgboost.XGBClassifier + objective: "binary:logistic" diff --git a/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Regression.yaml b/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Regression.yaml new file mode 100644 index 000000000000..37fdca8894c3 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/dataset/task/Regression.yaml @@ -0,0 +1,14 @@ +task_type: "REG" + +metric: + name: "mse" + fn: + _target_: torchmetrics.MeanSquaredError + +criterion: + _target_: torch.nn.MSELoss + + +xgb: + _target_: xgboost.XGBRegressor + objective: "reg:squarederror" \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/wandb/default.yaml b/baselines/hfedxgboost/hfedxgboost/conf/wandb/default.yaml new file mode 100644 index 000000000000..36968201fdc6 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/wandb/default.yaml @@ -0,0 +1,6 @@ +setup: + project: p1 + mode: online +watch: + log: all + log_freq: 100 \ No newline at end of file diff --git 
a/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/YearPredictionMSD_xgboost_centralized.yaml b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/YearPredictionMSD_xgboost_centralized.yaml new file mode 100644 index 000000000000..1912c7d5a015 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/YearPredictionMSD_xgboost_centralized.yaml @@ -0,0 +1,11 @@ +n_estimators: 300 +max_depth: 4 +subsample: 0.8 +learning_rate: .1 +colsample_bylevel: 1 +colsample_bynode: 1 +colsample_bytree: 1 +alpha: 5 +gamma: 5 +num_parallel_tree: 1 +min_child_weight: 1 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/abalone_xgboost_centralized.yaml b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/abalone_xgboost_centralized.yaml new file mode 100644 index 000000000000..72578770034d --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/abalone_xgboost_centralized.yaml @@ -0,0 +1,11 @@ +n_estimators: 200 +max_depth: 3 +subsample: .4 +learning_rate: .05 +colsample_bylevel: 1 +colsample_bynode: 1 +colsample_bytree: 1 +alpha: 5 +gamma: 10 +num_parallel_tree: 1 +min_child_weight: 5 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/cpusmall_xgboost_centralized.yaml b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/cpusmall_xgboost_centralized.yaml new file mode 100644 index 000000000000..33983403091d --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/cpusmall_xgboost_centralized.yaml @@ -0,0 +1,11 @@ +n_estimators: 10000 +max_depth: 3 +subsample: .4 +learning_rate: .05 +colsample_bylevel: .5 +colsample_bynode: 1 +colsample_bytree: 1 +alpha: 5 +gamma: 10 +num_parallel_tree: 1 +min_child_weight: 5 diff --git a/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/paper_xgboost_centralized.yaml 
b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/paper_xgboost_centralized.yaml new file mode 100644 index 000000000000..f439badb3ade --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/conf/xgboost_params_centralized/paper_xgboost_centralized.yaml @@ -0,0 +1,11 @@ +n_estimators: 500 +max_depth: 8 +subsample: 0.8 +learning_rate: .1 +colsample_bylevel: 1 +colsample_bynode: 1 +colsample_bytree: 1 +alpha: 5 +gamma: 5 +num_parallel_tree: 1 +min_child_weight: 1 \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/dataset.py b/baselines/hfedxgboost/hfedxgboost/dataset.py new file mode 100644 index 000000000000..a03ce2cd59fa --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/dataset.py @@ -0,0 +1,142 @@ +"""Handle basic dataset creation. + +In case of PyTorch it should return dataloaders for your dataset (for both the clients +and the server). If you are using a custom dataset class, this module is the place to +define it. If your dataset requires to be downloaded (and this is not done +automatically -- e.g. as it is the case for many dataset in TorchVision) and +partitioned, please include all those functions and logic in the +`dataset_preparation.py` module. You can use all those functions from functions/methods +defined here of course. +""" +from typing import List, Optional, Tuple, Union + +import torch +from flwr.common import NDArray +from torch.utils.data import DataLoader, Dataset, random_split + +from hfedxgboost.dataset_preparation import ( + datafiles_fusion, + download_data, + modify_labels, + train_test_split, +) + + +def load_single_dataset( + task_type: str, dataset_name: str, train_ratio: Optional[float] = 0.75 +) -> Tuple[NDArray, NDArray, NDArray, NDArray]: + """Load a single dataset. + + Parameters + ---------- + task_type (str): The type of task, either "BINARY" or "REG". + dataset_name (str): The name of the dataset to load. + train_ratio (float, optional): The ratio of training data to the total dataset. 
+ Default is 0.75. + + Returns + ------- + x_train (numpy array): The training data features. + y_train (numpy array): The training data labels. + X_test (numpy array): The testing data features. + y_test (numpy array): The testing data labels. + """ + datafiles_paths = download_data(dataset_name) + X, Y = datafiles_fusion(datafiles_paths) + x_train, y_train, x_test, y_test = train_test_split(X, Y, train_ratio=train_ratio) + if task_type.upper() == "BINARY": + y_train, y_test = modify_labels(y_train, y_test) + + print( + "First class ratio in train data", + y_train[y_train == 0.0].size / x_train.shape[0], + ) + print( + "Second class ratio in train data", + y_train[y_train != 0.0].size / x_train.shape[0], + ) + print( + "First class ratio in test data", + y_test[y_test == 0.0].size / x_test.shape[0], + ) + print( + "Second class ratio in test data", + y_test[y_test != 0.0].size / x_test.shape[0], + ) + + print("Feature dimension of the dataset:", x_train.shape[1]) + print("Size of the trainset:", x_train.shape[0]) + print("Size of the testset:", x_test.shape[0]) + + return x_train, y_train, x_test, y_test + + +def get_dataloader( + dataset: Dataset, partition: str, batch_size: Union[int, str] +) -> DataLoader: + """Return a DataLoader object. + + Parameters + ---------- + dataset (Dataset): The dataset object that contains the data. + partition (str): The partition string that specifies the subset of data to use. + batch_size (Union[int, str]): The batch size to use for loading data. + It can be either an integer value or the string "whole". + If "whole" is provided, the batch size will be set to the length of the dataset. + + Returns + ------- + DataLoader: A DataLoader object that loads data from the dataset in batches. 
+ """ + if batch_size == "whole": + batch_size = len(dataset) + return DataLoader( + dataset, batch_size=batch_size, pin_memory=True, shuffle=(partition == "train") + ) + + +def divide_dataset_between_clients( + trainset: Dataset, + testset: Dataset, + pool_size: int, + batch_size: Union[int, str], + val_ratio: float = 0.0, +) -> Tuple[DataLoader, Union[List[DataLoader], List[None]], DataLoader]: + """Divide the data between clients with IID distribution. + + Parameters + ---------- + trainset (Dataset): The full training dataset. + testset (Dataset): The full test dataset. + pool_size (int): The number of partitions to create. + batch_size (Union[int, str]): The size of the batches. + val_ratio (float, optional): The ratio of validation data. Defaults to 0.0. + + Returns + ------- + Tuple[DataLoader, DataLoader, DataLoader]: A tuple containing + the training loaders, validation loaders (or None), and test loader. + """ + # Split training set into `num_clients` partitions to simulate + # different local datasets + trainset_length = len(trainset) + lengths = [trainset_length // pool_size] * pool_size + if sum(lengths) != trainset_length: + lengths[-1] = trainset_length - sum(lengths[0:-1]) + datasets = random_split(trainset, lengths, torch.Generator().manual_seed(0)) + + # Split each partition into train/val and create DataLoader + trainloaders: List[DataLoader] = [] + valloaders: Union[List[DataLoader], List[None]] = [] + for dataset in datasets: + len_val = int(len(dataset) * val_ratio) + len_train = len(dataset) - len_val + ds_train, ds_val = random_split( + dataset, [len_train, len_val], torch.Generator().manual_seed(0) + ) + trainloaders.append(get_dataloader(ds_train, "train", batch_size)) + if len_val != 0: + valloaders.append(get_dataloader(ds_val, "val", batch_size)) + else: + valloaders.append(None) + return trainloaders, valloaders, get_dataloader(testset, "test", batch_size) diff --git a/baselines/hfedxgboost/hfedxgboost/dataset_preparation.py 
b/baselines/hfedxgboost/hfedxgboost/dataset_preparation.py new file mode 100644 index 000000000000..3fd3cbfb68fd --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/dataset_preparation.py @@ -0,0 +1,262 @@ +"""Handle the dataset partitioning and (optionally) complex downloads. + +Please add here all the necessary logic to either download, uncompress, pre/post-process +your dataset (or all of the above). If the desired way of running your baseline is to +first download the dataset and partition it and then run the experiments, please +uncomment the lines below and tell us in the README.md (see the "Running the Experiment" +block) that this file should be executed first. +""" +import bz2 +import os +import shutil +import urllib.request +from typing import Optional + +import numpy as np +from sklearn.datasets import load_svmlight_file + + +def download_data(dataset_name: Optional[str] = "cod-rna"): + """Download (if necessary) the dataset and returns the dataset path. + + Parameters + ---------- + dataset_name : String + A string stating the name of the dataset that need to be dowenloaded. 
+ + Returns + ------- + List[Dataset Pathes] + The pathes for the data that will be used in train and test, + with train of full dataset in index 0 + """ + all_datasets_path = "./dataset" + if dataset_name: + dataset_path = os.path.join(all_datasets_path, dataset_name) + match dataset_name: + case "a9a": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/a9a", + f"{os.path.join(dataset_path, 'a9a')}", + ) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/a9a.t", + f"{os.path.join(dataset_path, 'a9a.t')}", + ) + # training then test ✅ + return_list = [ + os.path.join(dataset_path, "a9a"), + os.path.join(dataset_path, "a9a.t"), + ] + case "cod-rna": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/cod-rna.t", + f"{os.path.join(dataset_path, 'cod-rna.t')}", + ) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/cod-rna.r", + f"{os.path.join(dataset_path, 'cod-rna.r')}", + ) + # training then test ✅ + return_list = [ + os.path.join(dataset_path, "cod-rna.t"), + os.path.join(dataset_path, "cod-rna.r"), + ] + + case "ijcnn1": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/ijcnn1.bz2", + f"{os.path.join(dataset_path, 'ijcnn1.tr.bz2')}", + ) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/ijcnn1.t.bz2", + f"{os.path.join(dataset_path, 'ijcnn1.t.bz2')}", + ) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/ijcnn1.tr.bz2", + f"{os.path.join(dataset_path, 'ijcnn1.tr.bz2')}", + ) + urllib.request.urlretrieve( + 
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/binary/ijcnn1.val.bz2", + f"{os.path.join(dataset_path, 'ijcnn1.val.bz2')}", + ) + + for filepath in os.listdir(dataset_path): + abs_filepath = os.path.join(dataset_path, filepath) + with bz2.BZ2File(abs_filepath) as freader, open( + abs_filepath[:-4], "wb" + ) as fwriter: + shutil.copyfileobj(freader, fwriter) + # training then test ✅ + return_list = [ + os.path.join(dataset_path, "ijcnn1.t"), + os.path.join(dataset_path, "ijcnn1.tr"), + ] + + case "space_ga": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/regression/space_ga_scale", + f"{os.path.join(dataset_path, 'space_ga_scale')}", + ) + return_list = [os.path.join(dataset_path, "space_ga_scale")] + case "abalone": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/regression/abalone_scale", + f"{os.path.join(dataset_path, 'abalone_scale')}", + ) + return_list = [os.path.join(dataset_path, "abalone_scale")] + case "cpusmall": + if not os.path.exists(dataset_path): + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/regression/cpusmall_scale", + f"{os.path.join(dataset_path, 'cpusmall_scale')}", + ) + return_list = [os.path.join(dataset_path, "cpusmall_scale")] + case "YearPredictionMSD": + if not os.path.exists(dataset_path): + print( + "long download coming -~615MB-, it'll be better if you downloaded", + "those 2 files manually with a faster download manager program or", + "something and just place them in the right folder then get", + "the for loop out of the if condition to alter their format", + ) + os.makedirs(dataset_path) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/regression/YearPredictionMSD.bz2", + 
f"{os.path.join(dataset_path, 'YearPredictionMSD.bz2')}", + ) + urllib.request.urlretrieve( + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" + "/regression/YearPredictionMSD.t.bz2", + f"{os.path.join(dataset_path, 'YearPredictionMSD.t.bz2')}", + ) + for filepath in os.listdir(dataset_path): + print("it will take sometime") + abs_filepath = os.path.join(dataset_path, filepath) + with bz2.BZ2File(abs_filepath) as freader, open( + abs_filepath[:-4], "wb" + ) as fwriter: + shutil.copyfileobj(freader, fwriter) + return_list = [ + os.path.join(dataset_path, "YearPredictionMSD"), + os.path.join(dataset_path, "YearPredictionMSD.t"), + ] + case _: + raise Exception("write your own dataset downloader") + return return_list + + +def datafiles_fusion(data_paths): + """Merge (if necessary) the data files and returns the features and labels. + + Parmetres: + data_paths: List[Dataset Pathes] + - The pathes for the data that will be used in train and test, + with train of full dataset in index 0 + Returns: + X: Numpy array + - The full features of the dataset. + y: Numpy array + - The full labels of the dataset. + """ + data = load_svmlight_file(data_paths[0], zero_based=False) + X = data[0].toarray() + Y = data[1] + for i in range(1, len(data_paths)): + data = load_svmlight_file( + data_paths[i], zero_based=False, n_features=X.shape[1] + ) + X = np.concatenate((X, data[0].toarray()), axis=0) + Y = np.concatenate((Y, data[1]), axis=0) + return X, Y + + +def train_test_split(X, y, train_ratio=0.75): + """Split the dataset into training and testing. + + Parameters + ---------- + X: Numpy array + The full features of the dataset. + y: Numpy array + The full labels of the dataset. + train_ratio: float + the ratio that training should take from the full dataset + + Returns + ------- + X_train: Numpy array + The training dataset features. + y_train: Numpy array + The labels of the training dataset. + X_test: Numpy array + The testing dataset features. 
+ y_test: Numpy array + The labels of the testing dataset. + """ + np.random.seed(2023) + y = np.expand_dims(y, axis=1) + full = np.concatenate((X, y), axis=1) + np.random.shuffle(full) + y = full[:, -1] # for last column + X = full[:, :-1] # for all but last column + num_training_samples = int(X.shape[0] * train_ratio) + + x_train = X[0:num_training_samples] + y_train = y[0:num_training_samples] + + x_test = X[num_training_samples:] + y_test = y[num_training_samples:] + + x_train.flags.writeable = True + y_train.flags.writeable = True + x_test.flags.writeable = True + y_test.flags.writeable = True + + return x_train, y_train, x_test, y_test + + +def modify_labels(y_train, y_test): + """Switch the -1 in the classification dataset with 0. + + Parameters + ---------- + y_train: Numpy array + The labels of the training dataset. + y_test: Numpy array + The labels of the testing dataset. + + Returns + ------- + y_train: Numpy array + The labels of the training dataset. + y_test: Numpy array + The labels of the testing dataset. + """ + y_train[y_train == -1] = 0 + y_test[y_test == -1] = 0 + return y_train, y_test diff --git a/baselines/hfedxgboost/hfedxgboost/main.py b/baselines/hfedxgboost/hfedxgboost/main.py new file mode 100644 index 000000000000..061e635f024c --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/main.py @@ -0,0 +1,143 @@ +"""Create and connect the building blocks for your experiments; start the simulation. + +It includes processioning the dataset, instantiate strategy, specify how the global +model is going to be evaluated, etc. At the end, this script saves the results. 
+""" +import functools +from typing import Dict, Union + +import flwr as fl +import hydra +import torch +import wandb +from flwr.common import Scalar +from flwr.server.app import ServerConfig +from flwr.server.client_manager import SimpleClientManager +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import TensorDataset + +from hfedxgboost.client import FlClient +from hfedxgboost.dataset import divide_dataset_between_clients, load_single_dataset +from hfedxgboost.server import FlServer, serverside_eval +from hfedxgboost.utils import ( + CentralizedResultsWriter, + EarlyStop, + ResultsWriter, + create_res_csv, + local_clients_performance, + run_centralized, +) + + +@hydra.main(config_path="conf", config_name="base", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # 1. Print parsed config + print(OmegaConf.to_yaml(cfg)) + writer: Union[ResultsWriter, CentralizedResultsWriter] + if cfg.centralized: + if cfg.dataset.dataset_name == "all": + run_centralized(cfg, dataset_name=cfg.dataset.dataset_name) + else: + writer = CentralizedResultsWriter(cfg) + create_res_csv("results_centralized.csv", writer.fields) + writer.write_res( + "results_centralized.csv", + run_centralized(cfg, dataset_name=cfg.dataset.dataset_name)[0], + run_centralized(cfg, dataset_name=cfg.dataset.dataset_name)[1], + ) + else: + if cfg.use_wandb: + wandb.init(**cfg.wandb.setup, group=f"{cfg.dataset.dataset_name}") + + print("Dataset Name", cfg.dataset.dataset_name) + early_stopper = EarlyStop(cfg) + x_train, y_train, x_test, y_test = load_single_dataset( + cfg.dataset.task.task_type, + cfg.dataset.dataset_name, + train_ratio=cfg.dataset.train_ratio, + ) + + trainloaders, valloaders, testloader = divide_dataset_between_clients( + TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)), + 
TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)), + batch_size=cfg.batch_size, + pool_size=cfg.clients.client_num, + val_ratio=cfg.val_ratio, + ) + print( + f"Data partitioned across {cfg.clients.client_num} clients" + f" and {cfg.val_ratio} of local dataset reserved for validation." + ) + if cfg.show_each_client_performance_on_its_local_data: + local_clients_performance( + cfg, trainloaders, x_test, y_test, cfg.dataset.task.task_type + ) + + # Configure the strategy + def fit_config(server_round: int) -> Dict[str, Scalar]: + print(f"Configuring round {server_round}") + return { + "num_iterations": cfg.run_experiment.fit_config.num_iterations, + "batch_size": cfg.run_experiment.batch_size, + } + + # FedXgbNnAvg + strategy = instantiate( + cfg.strategy, + on_fit_config_fn=fit_config, + on_evaluate_config_fn=( + lambda r: {"batch_size": cfg.run_experiment.batch_size} + ), + evaluate_fn=functools.partial( + serverside_eval, + cfg=cfg, + testloader=testloader, + ), + ) + + print( + f"FL experiment configured for {cfg.run_experiment.num_rounds} rounds with", + f"{cfg.clients.client_num} client in the pool.", + ) + + def client_fn(cid: str) -> fl.client.Client: + """Create a federated learning client.""" + return FlClient(cfg, trainloaders[int(cid)], valloaders[int(cid)], cid) + + # Start the simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + server=FlServer( + cfg=cfg, + client_manager=SimpleClientManager(), + early_stopper=early_stopper, + strategy=strategy, + ), + num_clients=cfg.clients.client_num, + client_resources=cfg.client_resources, + config=ServerConfig(num_rounds=cfg.run_experiment.num_rounds), + strategy=strategy, + ) + + print(history) + writer = ResultsWriter(cfg) + print( + "Best Result", + writer.extract_best_res(history)[0], + "best_res_round", + writer.extract_best_res(history)[1], + ) + create_res_csv("results.csv", writer.fields) + writer.write_res("results.csv") + + +if __name__ == "__main__": + main() 
diff --git a/baselines/hfedxgboost/hfedxgboost/models.py b/baselines/hfedxgboost/hfedxgboost/models.py new file mode 100644 index 000000000000..fbfc2d966f69 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/models.py @@ -0,0 +1,130 @@ +"""Define our models, and training and eval functions. + +If your model is 100% off-the-shelf (e.g. directly from torchvision without requiring +modifications) you might be better off instantiating your model directly from the Hydra +config. In this way, swapping your model for another one can be done without changing +the python code at all +""" +from collections import OrderedDict +from typing import Union + +import flwr as fl +import numpy as np +import torch +import torch.nn as nn +from flwr.common import NDArray +from hydra.utils import instantiate +from omegaconf import DictConfig +from xgboost import XGBClassifier, XGBRegressor + + +def fit_xgboost( + config: DictConfig, + task_type: str, + x_train: NDArray, + y_train: NDArray, + n_estimators: int, +) -> Union[XGBClassifier, XGBRegressor]: + """Fit XGBoost model to training data. + + Parameters + ---------- + config (DictConfig): Hydra configuration. + task_type (str): Type of task, "REG" for regression and "BINARY" + for binary classification. + x_train (NDArray): Input features for training. + y_train (NDArray): Labels for training. + n_estimators (int): Number of trees to build. + + Returns + ------- + Union[XGBClassifier, XGBRegressor]: Fitted XGBoost model. 
+ """ + if config.dataset.dataset_name == "all": + if task_type.upper() == "REG": + tree = instantiate(config.XGBoost.regressor, n_estimators=n_estimators) + elif task_type.upper() == "BINARY": + tree = instantiate(config.XGBoost.classifier, n_estimators=n_estimators) + else: + tree = instantiate(config.XGBoost) + tree.fit(x_train, y_train) + return tree + + +class CNN(nn.Module): + """CNN model.""" + + def __init__(self, cfg: DictConfig, n_channel: int = 64) -> None: + super().__init__() + n_out = 1 + self.task_type = cfg.dataset.task.task_type + n_estimators_client = cfg.n_estimators_client + client_num = cfg.client_num + + self.conv1d = nn.Conv1d( + in_channels=1, + out_channels=n_channel, + kernel_size=n_estimators_client, + stride=n_estimators_client, + padding=0, + ) + + self.layer_direct = nn.Linear(n_channel * client_num, n_out) + + self.relu = nn.ReLU() + + if self.task_type == "BINARY": + self.final_layer = nn.Sigmoid() + elif self.task_type == "REG": + self.final_layer = nn.Identity() + + # Add weight initialization + for layer in self.modules(): + if isinstance(layer, nn.Linear): + nn.init.kaiming_uniform_( + layer.weight, mode="fan_in", nonlinearity="relu" + ) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + """Perform a forward pass. + + Parameters + ---------- + input_features (torch.Tensor): Input features to the network. + + Returns + ------- + output (torch.Tensor): Output of the network after the forward pass. + """ + output = self.conv1d(input_features) + output = output.flatten(start_dim=1) + output = self.relu(output) + output = self.layer_direct(output) + output = self.final_layer(output) + return output + + def get_weights(self) -> fl.common.NDArrays: + """Get model weights. + + Parameters + ---------- + a list of NumPy arrays. 
+ """ + return [ + np.array(val.cpu().numpy(), copy=True) + for _, val in self.state_dict().items() + ] + + def set_weights(self, weights: fl.common.NDArrays) -> None: + """Set the CNN model weights. + + Parameters + ---------- + weights:a list of NumPy arrays + """ + layer_dict = {} + for key, value in zip(self.state_dict().keys(), weights): + if value.ndim != 0: + layer_dict[key] = torch.Tensor(np.array(value, copy=True)) + state_dict = OrderedDict(layer_dict) + self.load_state_dict(state_dict, strict=True) diff --git a/baselines/hfedxgboost/hfedxgboost/server.py b/baselines/hfedxgboost/hfedxgboost/server.py new file mode 100644 index 000000000000..e8ac6a24add8 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/server.py @@ -0,0 +1,406 @@ +"""A custom server class extends Flower's default server class to build a federated. + +learning setup that involves a combination of a CNN model and an XGBoost model, a +customized model aggregation that can work with this model combination, incorporate the +usage of an early stopping mechanism to stop training when needed and incorporate the +usage of wandb for fine-tuning purposes. 
+""" + +# Flower server + +import timeit +from logging import DEBUG, INFO +from typing import Dict, List, Optional, Tuple, Union + +import flwr as fl +import wandb +from flwr.common import EvaluateRes, FitRes, Parameters, Scalar, parameters_to_ndarrays +from flwr.common.logger import log +from flwr.common.typing import GetParametersIns +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.history import History +from flwr.server.server import evaluate_clients, fit_clients +from flwr.server.strategy import Strategy +from omegaconf import DictConfig +from torch.utils.data import DataLoader +from xgboost import XGBClassifier, XGBRegressor + +from hfedxgboost.models import CNN +from hfedxgboost.utils import EarlyStop, single_tree_preds_from_each_client, test + +FitResultsAndFailures = Tuple[ + List[Tuple[ClientProxy, FitRes]], + List[Union[Tuple[ClientProxy, FitRes], BaseException]], +] +EvaluateResultsAndFailures = Tuple[ + List[Tuple[ClientProxy, EvaluateRes]], + List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], +] + + +class FlServer(fl.server.Server): + """The FL_Server class is a sub-class of the fl.server.Server class. + + Attributes + ---------- + client_manager (ClientManager):responsible for managing the clients. + parameters (Parameters): The model parameters used for training + and evaluation. + strategy (Strategy): The strategy used for selecting clients + and aggregating results. + max_workers (None or int): The maximum number of workers + for parallel execution. + early_stopper (EarlyStop): The early stopper used for + determining when to stop training. + + Methods + ------- + fit_round(server_round, timeout): + Runs a round of fitting on the server side. + check_res_cen(current_round, timeout, start_time, history): + Gets results after fitting the model for the current round + and checks if the training should stop. 
+ fit(num_rounds, timeout): + Runs federated learning for a given number of rounds. + evaluate_round(server_round, timeout): + Validates the current global model on a number of clients. + _get_initial_parameters(timeout): + Gets initial parameters from one of the available clients. + serverside_eval(server_round, parameters, config, + cfg, testloader, batch_size): + Performs server-side evaluation. + """ + + def __init__( + self, + cfg: DictConfig, + client_manager: ClientManager, + early_stopper: EarlyStop, + strategy: Strategy, + ) -> None: + super().__init__(client_manager=client_manager) + self.cfg = cfg + self._client_manager = client_manager + self.parameters = Parameters(tensors=[], tensor_type="numpy.ndarray") + self.strategy = strategy + self.early_stopper = early_stopper + + def fit_round( + self, + server_round: int, + timeout: Optional[float], + ): + """Run a round of fitting on the server side. + + Parameters + ---------- + self: The instance of the class. + server_round (int): The round of server communication. + timeout (float, optional): Maximum time to wait for client responses. + + Returns + ------- + The aggregated CNN model and tree + Aggregated metric value + A tuple containing the results and failures. + + None if no clients were selected. 
+ """ + # Get clients and their respective instructions from strategy + client_instructions = self.strategy.configure_fit( + server_round=server_round, + parameters=self.parameters, + client_manager=self._client_manager, + ) + + if not client_instructions: + log(INFO, "fit_round %s: no clients selected, cancel", server_round) + return None + log( + DEBUG, + "fit_round %s: strategy sampled %s clients (out of %s)", + server_round, + len(client_instructions), + self._client_manager.num_available(), + ) + # Collect `fit` results from all clients participating in this round + if self.cfg.server.max_workers == "None": + max_workers = None + else: + max_workers = int(self.cfg.server.max_workers) + results, failures = fit_clients( + client_instructions=client_instructions, + max_workers=max_workers, + timeout=timeout, + ) + + log( + DEBUG, + "fit_round %s received %s results and %s failures", + server_round, + len(results), + len(failures), + ) + + # metrics_aggregated: Dict[str, Scalar] + aggregated_parm, metrics_aggregated = self.strategy.aggregate_fit( + server_round, results, failures + ) + # the tests is convinced that aggregated_parm is a Parameters | None + # which is not true as aggregated_parm is actually List[Union[Parameters,None]] + if aggregated_parm: + cnn_aggregated, trees_aggregated = aggregated_parm[0], aggregated_parm[1] # type: ignore # noqa: E501 # pylint: disable=line-too-long + else: + raise Exception("aggregated parameters is None") + + if isinstance(trees_aggregated, list): + print("Server side aggregated", len(trees_aggregated), "trees.") + else: + print("Server side did not aggregate trees.") + + return ( + (cnn_aggregated, trees_aggregated), + metrics_aggregated, + (results, failures), + ) + + def check_res_cen(self, current_round, timeout, start_time, history): + """Get results after fitting the model for the current round. + + Check if those results are not None and check + if the training should stop or not. 
+ + Parameters + ---------- + current_round (int): The current round number. + timeout (int): The time limit for the evaluation. + start_time (float): The starting time of the evaluation. + history (History): The object for storing the evaluation history. + + Returns + ------- + bool: True if the early stop criteria is met, False otherwise. + """ + res_fit = self.fit_round(server_round=current_round, timeout=timeout) + if res_fit: + parameters_prime, _, _ = res_fit + if parameters_prime: + self.parameters = parameters_prime + res_cen = self.strategy.evaluate(current_round, parameters=self.parameters) + if res_cen is not None: + loss_cen, metrics_cen = res_cen + log( + INFO, + "fit progress: (%s, %s, %s, %s)", + current_round, + loss_cen, + metrics_cen, + timeit.default_timer() - start_time, + ) + history.add_loss_centralized(server_round=current_round, loss=loss_cen) + history.add_metrics_centralized( + server_round=current_round, metrics=metrics_cen + ) + if self.cfg.use_wandb: + wandb.log({"server_metric_value": metrics_cen, "server_loss": loss_cen}) + if self.early_stopper.early_stop(res_cen): + return True + return False + + # pylint: disable=too-many-locals + def fit(self, num_rounds: int, timeout: Optional[float]) -> History: + """Run federated learning for a given number of rounds. + + Parameters + ---------- + num_rounds (int): The number of rounds to run federated learning. + timeout (Optional[float]): The optional timeout value for each round. + + Returns + ------- + History: The history object that stores the loss and metrics data. 
+ """ + history = History() + + # Initialize parameters + log(INFO, "Initializing global parameters") + self.parameters = self._get_initial_parameters(timeout=timeout) + + log(INFO, "Evaluating initial parameters") + res = self.strategy.evaluate(0, parameters=self.parameters) + if res is not None: + log( + INFO, + "initial parameters (loss, other metrics): %s, %s", + res[0], + res[1], + ) + history.add_loss_centralized(server_round=0, loss=res[0]) + history.add_metrics_centralized(server_round=0, metrics=res[1]) + + # Run federated learning for num_rounds + log(INFO, "FL starting") + start_time = timeit.default_timer() + for current_round in range(1, num_rounds + 1): + stop = self.check_res_cen(current_round, timeout, start_time, history) + # Evaluate model on a sample of available clients + res_fed = self.evaluate_round(server_round=current_round, timeout=timeout) + if res_fed: + loss_fed, evaluate_metrics_fed, _ = res_fed + if loss_fed: + history.add_loss_distributed( + server_round=current_round, loss=loss_fed + ) + history.add_metrics_distributed( + server_round=current_round, metrics=evaluate_metrics_fed + ) + # Stop if no progress is happening + if stop: + break + + # Bookkeeping + end_time = timeit.default_timer() + elapsed = end_time - start_time + log(INFO, "FL finished in %s", elapsed) + return history + + def evaluate_round( + self, + server_round: int, + timeout: Optional[float], + ) -> Optional[ + Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures] + ]: + """Validate current global model on a number of clients. + + Parameters + ---------- + server_round (int): representing the current server round + timeout (float, optional): The time limit for the request in seconds. + + Returns + ------- + Aggregated loss, + Aggregated metric, + Tuple of the results and failures. + or + None if no clients selected. 
+ """ + # Get clients and their respective instructions from strategy + client_instructions = self.strategy.configure_evaluate( + server_round=server_round, + parameters=self.parameters, + client_manager=self._client_manager, + ) + if not client_instructions: + log(INFO, "evaluate_round %s: no clients selected, cancel", server_round) + return None + log( + DEBUG, + "evaluate_round %s: strategy sampled %s clients (out of %s)", + server_round, + len(client_instructions), + self._client_manager.num_available(), + ) + + # Collect `evaluate` results from all clients participating in this round + results, failures = evaluate_clients( + client_instructions, + max_workers=self.max_workers, + timeout=timeout, + ) + log( + DEBUG, + "evaluate_round %s received %s results and %s failures", + server_round, + len(results), + len(failures), + ) + + # Aggregate the evaluation results + aggregated_result = self.strategy.aggregate_evaluate( + server_round, results, failures + ) + + loss_aggregated, metrics_aggregated = aggregated_result + return loss_aggregated, metrics_aggregated, (results, failures) + + def _get_initial_parameters(self, timeout: Optional[float]): + """Get initial parameters from one of the available clients. + + Parameters + ---------- + timeout (float, optional): The time limit for the request in seconds. + Defaults to None. + + Returns + ------- + parameters (tuple): A tuple containing the initial parameters. 
+ """ + log(INFO, "Requesting initial parameters from one random client") + random_client = self._client_manager.sample(1)[0] + ins = GetParametersIns(config={}) + get_parameters_res_tree = random_client.get_parameters(ins=ins, timeout=timeout) + log(INFO, "Received initial parameters from one random client") + + return (get_parameters_res_tree[0].parameters, get_parameters_res_tree[1]) # type: ignore # noqa: E501 # pylint: disable=line-too-long + + +def serverside_eval( + server_round: int, + parameters: Tuple[ + Parameters, + Union[ + Tuple[XGBClassifier, int], + Tuple[XGBRegressor, int], + List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]], + ], + ], + config: Dict[str, Scalar], + cfg: DictConfig, + testloader: DataLoader, +) -> Tuple[float, Dict[str, float]]: + """Perform server-side evaluation. + + Parameters + ---------- + server_round (int): The round of server evaluation. + parameters (Tuple): A tuple containing the parameters needed for evaluation. + First element: an instance of the Parameters class. + Second element: a tuple consists of either an XGBClassifier + or XGBRegressor model and an integer, or a list of that tuple. + config (Dict): A dictionary containing configuration parameters. + cfg: Hydra configuration object. + testloader (DataLoader): The data loader used for testing. + batch_size (int): The batch size used for testing. + + Returns + ------- + Tuple[float, Dict]: A tuple containing the evaluation loss (float) and + a dictionary containing the evaluation metric(s) (float). 
+ """ + print(config, server_round) + # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + metric_name = cfg.dataset.task.metric.name + + device = cfg.server.device + model = CNN(cfg) + model.set_weights(parameters_to_ndarrays(parameters[0])) + model.to(device) + + trees_aggregated = parameters[1] + testloader = single_tree_preds_from_each_client( + testloader, + cfg.run_experiment.batch_size, + trees_aggregated, + cfg.n_estimators_client, + cfg.client_num, + ) + loss, result, _ = test(cfg, model, testloader, device=device, log_progress=False) + + print( + f"Evaluation on the server: test_loss={loss:.4f},", + f"test_,{metric_name},={result:.4f}", + ) + return loss, {metric_name: result} diff --git a/baselines/hfedxgboost/hfedxgboost/strategy.py b/baselines/hfedxgboost/hfedxgboost/strategy.py new file mode 100644 index 000000000000..eb067a89e5f0 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/strategy.py @@ -0,0 +1,79 @@ +"""Optionally define a custom strategy. + +Needed only when the strategy is not yet implemented in Flower or because you want to +extend or modify the functionality of an existing strategy. +""" +from logging import WARNING +from typing import Any, Dict, List, Optional, Tuple, Union + +from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.logger import log +from flwr.server.client_proxy import ClientProxy + +from flwr.server.strategy.aggregate import aggregate +from flwr.server.strategy import FedAvg + + +class FedXgbNnAvg(FedAvg): + """Configurable FedXgbNnAvg strategy implementation.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Federated XGBoost [Ma et al., 2023] strategy. + + Implementation based on https://arxiv.org/abs/2304.07537. 
+ """ + super().__init__(*args, **kwargs) + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedXgbNnAvg(accept_failures={self.accept_failures})" + return rep + + def evaluate( + self, server_round: int, parameters: Any + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function.""" + if self.evaluate_fn is None: + # No evaluation function provided + return None + eval_res = self.evaluate_fn(server_round, parameters, {}) + if eval_res is None: + return None + loss, metrics = eval_res + return loss, metrics + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Any], Dict[str, Scalar]]: + """Aggregate fit results using weighted average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results + weights_results = [ + ( + parameters_to_ndarrays(fit_res.parameters[0].parameters), # type: ignore # noqa: E501 # pylint: disable=line-too-long + fit_res.num_examples, + ) + for _, fit_res in results + ] + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + + # Aggregate XGBoost trees from all clients + trees_aggregated = [fit_res.parameters[1] for _, fit_res in results] # type: ignore # noqa: E501 # pylint: disable=line-too-long + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return [parameters_aggregated, trees_aggregated], metrics_aggregated diff --git 
a/baselines/hfedxgboost/hfedxgboost/sweep.yaml b/baselines/hfedxgboost/hfedxgboost/sweep.yaml new file mode 100644 index 000000000000..cd2ca4b118b7 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/sweep.yaml @@ -0,0 +1,28 @@ +program: main.py +method: random +#metric part is just here for documentation +#not needed as this is random strategy +metric: + goal: minimize + name: server_loss +parameters: + use_wandb: + value: True + dataset: + value: [cpusmall] + clients: + value: cpusmall_5_clients + clients.n_estimators_client: + values: [10,25,50,100,150,200,250] + clients.num_iterations: + values: [100,500] + clients.xgb.max_depth: + values: [4,5,6,7,8] + clients.CNN.lr: + min: .00001 + max: 0.001 +command: + - ${env} + - python + - ${program} + - ${args_no_hyphens} \ No newline at end of file diff --git a/baselines/hfedxgboost/hfedxgboost/utils.py b/baselines/hfedxgboost/hfedxgboost/utils.py new file mode 100644 index 000000000000..67450f7b8af7 --- /dev/null +++ b/baselines/hfedxgboost/hfedxgboost/utils.py @@ -0,0 +1,566 @@ +"""Define any utility function. + +They are not directly relevant to the other (more FL specific) python modules. For +example, you may define here things like: loading a model from a checkpoint, saving +results, plotting. 
+""" +import csv +import math +import os +import os.path +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from flwr.common import NDArray +from hydra.utils import instantiate +from omegaconf import DictConfig +from sklearn.metrics import accuracy_score, mean_squared_error +from torch.utils.data import DataLoader, Dataset, TensorDataset +from tqdm import tqdm +from xgboost import XGBClassifier, XGBRegressor + +from hfedxgboost.dataset import load_single_dataset +from hfedxgboost.models import CNN, fit_xgboost + +dataset_tasks = { + "a9a": "BINARY", + "cod-rna": "BINARY", + "ijcnn1": "BINARY", + "abalone": "REG", + "cpusmall": "REG", + "space_ga": "REG", + "YearPredictionMSD": "REG", +} + + +def get_dataloader( + dataset: Dataset, partition: str, batch_size: Union[int, str] +) -> DataLoader: + """Create a DataLoader object for the given dataset. + + Parameters + ---------- + dataset (Dataset): The dataset object containing the data to be loaded. + partition (str): The partition of the dataset to load. + batch_size (Union[int, str]): The size of each mini-batch. + If "whole" is specified, the entire dataset will be included in a single batch. + + Returns + ------- + loader (DataLoader): The DataLoader object. + """ + if batch_size == "whole": + batch_size = len(dataset) + return DataLoader( + dataset, batch_size=batch_size, pin_memory=True, shuffle=(partition == "train") + ) + + +def evaluate(task_type, y, preds) -> float: + """Evaluate the performance of a given model prediction based on the task type. + + Parameters + ---------- + task_type: A string representing the type of the task + (either "BINARY" or "REG"). + y: The true target values. + preds: The predicted target values. + + Returns + ------- + result: The evaluation result based on the task type. If task_type is "BINARY", + it computes the accuracy score. + If task_type is "REG", it calculates the mean squared error. 
+ """ + if task_type.upper() == "BINARY": + result = accuracy_score(y, preds) + elif task_type.upper() == "REG": + result = mean_squared_error(y, preds) + return result + + +def run_single_exp( + config, dataset_name, task_type, n_estimators +) -> Tuple[float, float]: + """Run a single experiment using XGBoost on a dataset. + + Parameters + ---------- + - config (object): Hydra Configuration object containing necessary settings. + - dataset_name (str): Name of the dataset to train and test on. + - task_type (str): Type of the task "BINARY" or "REG". + - n_estimators (int): Number of estimators (trees) to use in the XGBoost model. + + Returns + ------- + - result_train (float): Evaluation result on the training set. + - result_test (float): Evaluation result on the test set. + """ + x_train, y_train, x_test, y_test = load_single_dataset( + task_type, dataset_name, train_ratio=config.dataset.train_ratio + ) + tree = fit_xgboost(config, task_type, x_train, y_train, n_estimators) + preds_train = tree.predict(x_train) + result_train = evaluate(task_type, y_train, preds_train) + preds_test = tree.predict(x_test) + result_test = evaluate(task_type, y_test, preds_test) + return result_train, result_test + + +def run_centralized( + config: DictConfig, dataset_name: str = "all", task_type: Optional[str] = None +) -> Union[Tuple[float, float], List[None]]: + """Run the centralized training and testing process. + + Parameters + ---------- + config (DictConfig): Hydra configuration object. + dataset_name (str): Name of the dataset to run the experiment on. + task_type (str): Type of task. + + Returns + ------- + None: Returns None if dataset_name is "all". + Tuple: Returns a tuple (result_train, result_test) of training and testing + results if dataset_name is not "all" and task_type is specified. + + Raises + ------ + Exception: Raises an exception if task_type is not specified correctly + and the dataset_name is not in the dataset_tasks dict. 
+ """ + if dataset_name == "all": + for dataset in dataset_tasks: + result_train, result_test = run_single_exp( + config, dataset, dataset_tasks[dataset], config.n_estimators + ) + print( + "Results for", + dataset, + ", Task:", + dataset_tasks[dataset], + ", Train:", + result_train, + ", Test:", + result_test, + ) + return [] + + if task_type: + result_train, result_test = run_single_exp( + config, + dataset_name, + task_type, + config.xgboost_params_centralized.n_estimators, + ) + print( + "Results for", + dataset_name, + ", Task:", + task_type, + ", Train:", + result_train, + ", Test:", + result_test, + ) + return result_train, result_test + + if dataset_name in dataset_tasks.keys(): + result_train, result_test = run_single_exp( + config, + dataset_name, + dataset_tasks[dataset_name], + config.xgboost_params_centralized.n_estimators, + ) + print( + "Results for", + dataset_name, + ", Task:", + dataset_tasks[dataset_name], + ", Train:", + result_train, + ", Test:", + result_test, + ) + return result_train, result_test + + raise Exception( + "task_type should be assigned to be BINARY for" + "binary classification" + "tasks or REG for regression tasks" + "or the dataset should be one of the follwing" + "a9a, cod-rna, ijcnn1, space_ga, abalone,", + "cpusmall, YearPredictionMSD", + ) + + +def local_clients_performance( + config: DictConfig, trainloaders, x_test, y_test, task_type: str +) -> None: + """Evaluate the performance of clients on local data using XGBoost. + + Parameters + ---------- + config (DictConfig): Hydra configuration object. + trainloaders: List of data loaders for each client. + x_test: Test features. + y_test: Test labels. + task_type (str): Type of prediction task. 
+ """ + for i, trainloader in enumerate(trainloaders): + for local_dataset in trainloader: + local_x_train, local_y_train = local_dataset[0], local_dataset[1] + tree = fit_xgboost( + config, + task_type, + local_x_train, + local_y_train, + 500 // config.client_num, + ) + + preds_train = tree.predict(local_x_train) + result_train = evaluate(task_type, local_y_train, preds_train) + + preds_test = tree.predict(x_test) + result_test = evaluate(task_type, y_test, preds_test) + print("Local Client %d XGBoost Training Results: %f" % (i, result_train)) + print("Local Client %d XGBoost Testing Results: %f" % (i, result_test)) + + +def single_tree_prediction( + tree, + n_tree: int, + dataset: NDArray, +) -> Optional[NDArray]: + """Perform a single tree prediction using the provided tree object on given dataset. + + Parameters + ---------- + tree (either XGBClassifier or XGBRegressor): The tree object + used for prediction. + n_tree (int): The index of the tree to be used for prediction. + dataset (NDArray): The dataset for which the prediction is to be made. + + Returns + ------- + NDArray object: representing the prediction result. + None: If the provided n_tree is larger than the total number of trees + in the tree object, and a warning message is printed. + """ + num_t = len(tree.get_booster().get_dump()) + if n_tree > num_t: + print( + "The tree index to be extracted is larger than the total number of trees." + ) + return None + + return tree.predict( + dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True + ) + + +def single_tree_preds_from_each_client( + trainloader: DataLoader, + batch_size, + client_tree_ensamples: Union[ + Tuple[XGBClassifier, int], + Tuple[XGBRegressor, int], + List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]], + ], + n_estimators_client: int, + client_num: int, +) -> Optional[Tuple[NDArray, NDArray]]: + """Predict using trees from client tree ensamples. 
+ + Extract each tree from each tree ensample from each client, + and predict the output of the data using that tree, + place those predictions in the preds_from_all_trees_from_all_clients, + and return it. + + Parameters + ---------- + trainloader: + - a dataloader that contains the dataset to be predicted. + client_tree_ensamples: + - the trained XGBoost tree ensample from each client, + each tree ensembles comes attached + to its client id in a tuple + - can come as a single tuple of XGBoost tree ensample and + its client id or multiple tuples in one list. + + Returns + ------- + loader (DataLoader): The DataLoader object that contains the + predictions of the tree + """ + if trainloader is None: + return None + + for local_dataset in trainloader: + x_train, y_train = local_dataset[0], np.float32(local_dataset[1]) + + preds_from_all_trees_from_all_clients = np.zeros( + (x_train.shape[0], client_num * n_estimators_client), dtype=np.float32 + ) + if isinstance(client_tree_ensamples, list) is False: + temp_trees = [client_tree_ensamples[0]] * client_num + elif isinstance(client_tree_ensamples, list): + client_tree_ensamples.sort(key=lambda x: x[1]) + temp_trees = [i[0] for i in client_tree_ensamples] + if len(client_tree_ensamples) != client_num: + temp_trees += [client_tree_ensamples[0][0]] * ( + client_num - len(client_tree_ensamples) + ) + + for i, _ in enumerate(temp_trees): + for j in range(n_estimators_client): + preds_from_all_trees_from_all_clients[ + :, i * n_estimators_client + j + ] = single_tree_prediction(temp_trees[i], j, x_train) + + preds_from_all_trees_from_all_clients = torch.from_numpy( + np.expand_dims(preds_from_all_trees_from_all_clients, axis=1) + ) + y_train = torch.from_numpy(np.expand_dims(y_train, axis=-1)) + tree_dataset = TensorDataset(preds_from_all_trees_from_all_clients, y_train) + return get_dataloader(tree_dataset, "tree", batch_size) + + +def test( + cfg, + net: CNN, + testloader: DataLoader, + device: torch.device, + 
log_progress: bool = True, +) -> Tuple[float, float, int]: + """Evaluates the performance of a CNN model on a given test dataset. + + Parameters + ---------- + cfg (Any): The configuration object. + net (CNN): The CNN model to test. + testloader (DataLoader): The data loader for the test dataset. + device (torch.device): The device to run the evaluation on. + log_progress (bool, optional): Whether to log the progress during evaluation. + Default is True. + + Returns + ------- + Tuple[float, float, int]: A tuple containing the average loss, + average metric result, and total number of samples evaluated. + """ + criterion = instantiate(cfg.dataset.task.criterion) + metric_fn = instantiate(cfg.dataset.task.metric.fn) + + total_loss, total_result, n_samples = 0.0, 0.0, 0 + net.eval() + with torch.no_grad(): + # progress_bar = tqdm(testloader, desc="TEST") if log_progress else testloader + for data in tqdm(testloader, desc="TEST") if log_progress else testloader: + tree_outputs, labels = data[0].to(device), data[1].to(device) + outputs = net(tree_outputs) + total_loss += criterion(outputs, labels).item() + n_samples += labels.size(0) + metric_val = metric_fn(outputs.cpu(), labels.type(torch.int).cpu()) + total_result += metric_val * labels.size(0) + + if log_progress: + print("\n") + + return total_loss / n_samples, total_result / n_samples, n_samples + + +class EarlyStop: + """Stop the tain when no progress is happening.""" + + def __init__(self, cfg): + self.num_waiting_rounds = cfg.dataset.early_stop_patience_rounds + self.counter = 0 + self.min_loss = float("inf") + self.metric_value = None + + def early_stop(self, res) -> Optional[Tuple[float, float]]: + """Check if the model made any progress in number of rounds. + + If it didn't it will return the best result and the server + will stop running the fit function, if + it did it will return None, and won't stop the server. 
+ + Parameters + ---------- + res: tuple of 2 elements, res[0] is a float that indicate the loss, + res[1] is actually a 1 element dictionary that looks like this + {'Accuracy': tensor(0.8405)} + + Returns + ------- + Optional[Tuple[float,float]]: (best loss the model achieved, + best metric value associated with that loss) + """ + loss = res[0] + metric_val = list(res[1].values())[0].item() + if loss < self.min_loss: + self.min_loss = loss + self.metric_value = metric_val + self.counter = 0 + print( + "New best loss value achieved,", + "loss", + self.min_loss, + "metric value", + self.metric_value, + ) + elif loss > (self.min_loss): + self.counter += 1 + if self.counter >= self.num_waiting_rounds: + print( + "That training is been stopped as the", + "model achieve no progress with", + "loss =", + self.min_loss, + "result =", + self.metric_value, + ) + return (self.metric_value, self.min_loss) + return None + + +# results + + +def create_res_csv(filename, fields) -> None: + """Create a CSV file with the provided file name.""" + if not os.path.isfile(filename): + with open(filename, "w") as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(fields) + + +class ResultsWriter: + """Write the results for the federated experiments.""" + + def __init__(self, cfg) -> None: + self.cfg = cfg + self.tas_type = cfg.dataset.task.task_type + if self.tas_type == "REG": + self.best_res = math.inf + self.compare_fn = min + if self.tas_type == "BINARY": + self.best_res = -1 + self.compare_fn = max + self.best_res_round_num = 0 + self.fields = [ + "dataset_name", + "client_num", + "n_estimators_client", + "num_rounds", + "xgb_max_depth", + "cnn_lr", + "best_res", + "best_res_round_num", + "num_iterations", + ] + + def extract_best_res(self, history) -> Tuple[float, int]: + """Take the history & returns the best result and its corresponding round num. 
+ + Parameters + ---------- + history: a history object that contains metrics_centralized keys + + Returns + ------- + Tuple[float, int]: a tuple containing the best result (float) and + its corresponding round number (int) + """ + for key in history.metrics_centralized.keys(): + for i in history.metrics_centralized[key]: + if ( + self.compare_fn(i[1].item(), self.best_res) == i[1] + and i[1].item() != self.best_res + ): + self.best_res = i[1].item() + self.best_res_round_num = i[0] + return (self.best_res, self.best_res_round_num) + + def write_res(self, filename) -> None: + """Write the results of the federated model to a CSV file. + + The function opens the specified file in 'a' (append) mode and creates a + csvwriter object and add the dataset name, xgboost model's and CNN model's + hyper-parameters used, and the result. + + Parameters + ---------- + filename: string that indicates the CSV file that will be written in. + """ + row = [ + str(self.cfg.dataset.dataset_name), + str(self.cfg.client_num), + str(self.cfg.clients.n_estimators_client), + str(self.cfg.run_experiment.num_rounds), + str(self.cfg.clients.xgb.max_depth), + str(self.cfg.clients.CNN.lr), + str(self.best_res), + str(self.best_res_round_num), + str(self.cfg.run_experiment.fit_config.num_iterations), + ] + with open(filename, "a") as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(row) + + +class CentralizedResultsWriter: + """Write the results for the centralized experiments.""" + + def __init__(self, cfg) -> None: + self.cfg = cfg + self.tas_type = cfg.dataset.task.task_type + self.fields = [ + "dataset_name", + "n_estimators_client", + "xgb_max_depth", + "subsample", + "learning_rate", + "colsample_bylevel", + "colsample_bynode", + "colsample_bytree", + "alpha", + "gamma", + "num_parallel_tree", + "min_child_weight", + "result_train", + "result_test", + ] + + def write_res(self, filename, result_train, result_test) -> None: + """Write the results of the centralized model to a 
CSV file. + + The function opens the specified file in 'a' (append) mode and creates a + csvwriter object and add the dataset name, xgboost's + hyper-parameters used, and the result. + + Parameters + ---------- + filename: string that indicates the CSV file that will be written in. + """ + row = [ + str(self.cfg.dataset.dataset_name), + str(self.cfg.xgboost_params_centralized.n_estimators), + str(self.cfg.xgboost_params_centralized.max_depth), + str(self.cfg.xgboost_params_centralized.subsample), + str(self.cfg.xgboost_params_centralized.learning_rate), + str(self.cfg.xgboost_params_centralized.colsample_bylevel), + str(self.cfg.xgboost_params_centralized.colsample_bynode), + str(self.cfg.xgboost_params_centralized.colsample_bytree), + str(self.cfg.xgboost_params_centralized.alpha), + str(self.cfg.xgboost_params_centralized.gamma), + str(self.cfg.xgboost_params_centralized.num_parallel_tree), + str(self.cfg.xgboost_params_centralized.min_child_weight), + str(result_train), + str(result_test), + ] + with open(filename, "a") as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(row) diff --git a/baselines/hfedxgboost/pyproject.toml b/baselines/hfedxgboost/pyproject.toml new file mode 100644 index 000000000000..1fbbb85ac36b --- /dev/null +++ b/baselines/hfedxgboost/pyproject.toml @@ -0,0 +1,145 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "hfedxgboost" # <----- Ensure it matches the name of your baseline directory containing all the source code +version = "1.0.0" +description = "HFedXgboost: Gradient-less Federated Gradient Boosting Trees with Learnable Learning Rates" +license = "Apache-2.0" +authors = ["The Flower Authors ", "Aml Hassan Esmil "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + 
"Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10.0, <3.11.0" # don't change this +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +torch = "1.13.1" +scikit-learn = "1.3.0" +xgboost = "2.0.0" +torchmetrics = "1.1.2" +tqdm = "4.66.1" +torchvision = "0.14.1" +wandb = "0.15.12" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" +virtualenv = "^20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = 
"bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/baselines/hfedxgboost/results.csv b/baselines/hfedxgboost/results.csv new file mode 100644 index 000000000000..678421bec48e --- /dev/null +++ b/baselines/hfedxgboost/results.csv @@ -0,0 +1,155 @@ +dataset_name,client_num,n_estimators_client,num_rounds,xgb_max_depth,CNN_lr,best_res,best_res_round_num,num_iterations +a9a,2,250,30,8,0.0001,0.8429285287857056,30,100 +a9a,2,250,40,8,0.0001,0.8376873135566711,40,100 +a9a,2,250,20,8,0.0001,0.7641471028327942,20,100 +a9a,2,500,30,8,0.001,0.8455491065979004,29,100 +ijcnn1,2,250,30,8,0.0005,0.9234752058982849,30,100 +ijcnn1,2,300,50,5,0.0005,0.9468683004379272,49,100 +cod-rna,10,100,30,8,0.001,0.9503719806671143,30,100 +ijcnn1,10,100,30,8,0.0005,0.938660204410553,28,100 +ijcnn1,10,100,30,8,0.0005,0.9379972219467163,28,100 +a9a,10,100,30,8,0.001,0.8403897881507874,26,100 +cod-rna,2,250,30,8,0.0001,0.9640399813652039,30,100 
+cod-rna,5,100,30,8,0.001,0.9620075225830078,29,100 +abalone,5,100,30,8,0.0001,10.37040901184082,29,100 +abalone,5,100,30,4,0.0001,10.158971786499023,28,100 +abalone,5,100,30,4,0.001,10.370624542236328,24,100 +abalone,5,50,30,2,0.01,10.314363479614258,27,100 +abalone,5,500,30,2,0.01,59.639190673828125,30,100 +abalone,5,500,50,8,0.0001,10.406515121459961,39,100 +abalone,5,100,50,4,0.0001,10.35323715209961,49,100 +abalone,5,100,50,4,0.0001,10.39498233795166,49,100 +ijcnn1,2,250,50,8,0.0001,0.9802058339118958,45,500 +a9a,2,250,30,8,0.0001,0.8449758291244507,23,500 +ijcnn1,5,100,30,8,0.0001,0.9731026887893677,16,500 +ijcnn1,10,100,30,8,0.0005,0.967893660068512,23,500 +abalone,10,100,30,8,0.0005,10.410109519958496,29,500 +cpusmall,2,250,50,8,0.0001,292.26702880859375,49,500 +cpusmall,2,250,2,4,0.0001,321.2448425292969,1,500 +cpusmall,2,100,50,4,0.0001,263.4275207519531,50,500 +cpusmall,2,100,150,4,0.0001,93.43792724609375,149,500 +cpusmall,2,100,1000,4,0.0001,14.8245849609375,810,500 +space_ga,2,100,50,4,1e-05,0.025414548814296722,4,500 +space_ga,5,100,40,4,1e-06,0.060075871646404266,40,100 +space_ga,5,100,100,4,1e-06,0.03812913969159126,0,100 +space_ga,5,100,200,4,1e-06,0.04440382868051529,0,500 +space_ga,10,50,50,4,1e-05,0.23313003778457642,0,500 +space_ga,10,100,50,4,1e-05,0.025433368980884552,22,100 +a9a,2,250,30,8,0.0001,0.8328556418418884,26,100 +cpusmall,10,25,300,4,1e-05,289.2887268066406,299,50 +cpusmall,10,25,300,4,0.0001,171.87559509277344,300,100 +cpusmall,10,25,1000,4,0.0001,21.770994186401367,994,100 +YearPredictionMSD,2,250,20,8,0.0001,119.265510559082,20,100 +cpusmall,2,100,1000,4,0.0001,15.070291519165,587,500 +cpusmall,2,100,800,8,0.0001,11.7772,679,500 +a9a,2,250,30,8,0.0001,0.8457129001617432,20,500 +a9a,2,250,20,8,0.0001,0.8449758291244507,15,500 +a9a,2,250,15,8,0.0001,0.8448120355606079,9,500 +a9a,2,250,15,8,0.0001,0.8441569209098816,15,500 +a9a,2,250,15,8,0.0001,0.8456310033798218,14,500 +a9a,2,250,15,8,0.0001,0.8412087559700012,15,500 
+a9a,5,100,30,8,0.0005,0.8429285287857056,30,500 +a9a,10,100,30,8,0.001,0.8362951278686523,14,100 +a9a,10,100,30,8,0.001,0.8377692103385925,15,100 +a9a,10,100,30,8,0.001,0.837441623210907,24,500 +cod-rna,2,250,30,8,0.0001,0.965569019317627,27,500 +cod-rna,5,100,30,8,0.001,0.9636204242706299,24,500 +ijcnn1,2,250,50,8,0.0001,0.9803005456924438,45,500 +abalone,10,50,30,8,0.0001,10.225870132446289,11,100 +abalone,2,250,30,8,0.0001,10.415020942687988,30,100 +cpusmall,2,500,400,8,0.0001,99999999999.999,0,500 +cpusmall,5,100,1,4,5e-05,3044.56689453125,1,500 +cpusmall,5,100,1,4,5e-05,3168.800537109375,1,500 +abalone,2,250,30,4,0.0001,10.305878639221191,27,100 +abalone,2,150,50,8,0.0001,10.236144065856934,50,100 +abalone,2,150,50,8,0.0001,10.031058311462402,50,300 +a9a,2,250,15,8,0.0001,0.8440750241279602,13,500 +a9a,5,100,30,8,0.0005,0.8422733545303345,21,500 +a9a,10,100,30,8,0.001,0.8367046117782593,22,500 +cod-rna,2,250,30,8,0.0001,0.9654850959777832,28,500 +cod-rna,5,50,1000,8,0.0001,-1,0,500 +cod-rna,5,100,30,8,0.001,0.9605903625488281,29,100 +abalone,5,100,50,4,0.03346913907439582,10.539855003356934,17,100 +abalone,5,100,50,4,0.034568059419072865,10.53787899017334,17,100 +abalone,5,100,50,4,0.04630995211462259,10.401963233947754,3,100 +abalone,5,100,50,4,0.05173366161706037,10.5382719039917,12,100 +abalone,5,100,50,4,0.08622379813597802,10.538084983825684,19,100 +abalone,5,100,50,4,0.05839912417408451,10.541743278503418,14,100 +abalone,5,100,50,4,0.029151079702933427,10.53785514831543,20,100 +abalone,5,100,50,4,0.0006444558899121101,10.226868629455566,49,100 +abalone,5,100,50,4,0.056951124793294075,10.537845611572266,11,100 +abalone,5,100,50,4,0.05010996385142899,10.539204597473145,22,100 +abalone,5,50,50,4,0.019894781508274385,10.012669563293457,42,100 +abalone,5,100,500,4,0.0001,9.894426345825195,120,500 +abalone,5,100,500,4,0.0001,10.059885025024414,75,500 +abalone,5,100,50,4,0.0001,9.870953559875488,49,100 +abalone,5,300,50,4,0.0001,10.408512115478516,45,100 
+abalone,5,200,50,4,0.0001,10.455050468444824,47,100 +abalone,5,300,50,4,0.0001,10.496016502380371,46,100 +abalone,5,300,50,4,0.0001,10.372100830078125,9,100 +abalone,5,300,50,8,0.0017554930858480753,10.486079216003418,19,500 +abalone,5,200,50,8,0.045167641025122586,10.540440559387207,13,100 +ijcnn1,2,250,50,8,0.0001,0.9805530905723572,50,500 +abalone,2,650,100,4,0.001217139626212002,10.490303993225098,10,500 +abalone,2,150,100,6,0.009503176921678468,7.928699016571045,98,500 +abalone,2,350,100,8,0.00783944676415376,9.897452354431152,79,100 +abalone,2,200,100,8,0.008207079744179863,10.53787612915039,28,100 +abalone,2,650,100,4,0.001217139626212002,10.463903427124023,23,500 +abalone,2,150,100,6,0.009503176921678468,7.670563220977783,100,500 +abalone,2,350,100,8,0.00783944676415376,10.538168907165527,29,100 +abalone,2,150,100,5,0.009531130088689884,8.326595306396484,97,500 +abalone,2,200,100,6,0.0054653156736999805,8.855677604675293,98,500 +abalone,2,50,100,5,0.0017955547171898984,6.178965091705322,96,500 +abalone,2,150,100,5,0.004387414920659405,8.143877983093262,100,500 +abalone,2,25,100,5,0.00950017009848373,9.77553939819336,27,500 +abalone,2,50,100,5,0.004882675431964606,5.8692779541015625,96,500 +abalone,2,25,100,5,0.007557058187901093,6.141476154327393,100,100 +abalone,2,10,100,7,0.005130176232078475,7.082235813140869,92,100 +abalone,2,50,100,6,0.002823108080376602,5.545757293701172,95,100 +abalone,2,10,100,6,0.0028521611812793155,7.876493453979492,98,100 +abalone,5,25,100,5,0.009785294236382034,9.208296775817871,99,500 +abalone,5,100,100,6,0.007079031827661472,10.32761001586914,1,500 +abalone,5,10,100,7,0.001344804317317952,8.094303131103516,100,500 +abalone,5,25,100,5,0.004714981058473264,9.583475112915039,96,500 +abalone,5,100,100,7,0.006981291837612061,9.868386268615723,100,100 +abalone,5,10,100,7,0.004135357042596778,9.329564094543457,98,500 +abalone,5,100,100,7,0.00362985132837509,10.195948600769043,45,500 
+abalone,5,50,100,6,0.009059669650367688,10.281177520751953,7,500 +abalone,5,25,100,7,0.005573120595947365,10.041385650634766,2,500 +abalone,5,25,100,7,0.00028621587028303236,9.260563850402832,100,500 +abalone,5,25,100,8,1.5711330591445135e-05,10.278858184814453,34,500 +abalone,5,5,100,8,0.00012332785174116285,9.62313175201416,10,500 +abalone,5,25,100,4,0.0006278898956762251,8.320752143859863,97,100 +abalone,5,25,100,6,0.0003372569452123535,8.360357284545898,99,500 +abalone,5,5,100,4,0.0002038313592678424,9.912229537963867,10,100 +abalone,5,5,100,7,0.00041985945288375667,9.698519706726074,19,500 +abalone,5,25,100,6,0.0007804153506712646,7.4677205085754395,100,500 +abalone,5,5,100,4,0.0005335030582873437,9.64600944519043,5,100 +abalone,5,25,100,6,0.0009397031028496506,7.2984700202941895,100,500 +abalone,5,25,100,8,0.0002193902512638236,9.670918464660645,99,100 +abalone,5,5,100,6,0.0005419520842675534,9.692630767822266,3,100 +abalone,5,25,100,4,0.0007623175785101614,8.790078163146973,99,500 +abalone,10,10,100,6,0.0028231080803766,9.8931884765625,10,100 +abalone,10,10,100,6,0.0028231080803766,9.223672866821289,100,100 +abalone,10,10,200,6,0.0028231080803766,9.795281410217285,64,100 +abalone,5,50,200,6,0.0004549072000953885,6.875454425811768,200,500 +cpusmall,5,25,500,6,0.000457414512764587,14.813268661499023,407,100 +cpusmall,5,25,500,6,0.0003909693357440312,15.414558410644531,292,500 +cpusmall,5,25,400,6,0.000457414512764587,15.857512474060059,366,100 +cpusmall,5,25,450,6,0.000457414512764587,15.76364803314209,208,100 +cpusmall,5,25,500,6,0.000457414512764587,15.136177062988281,234,100 +space_ga,2,10,20,8,1e-05,0.025517474859952927,9,100 +space_ga,2,10,20,8,1e-05,0.02540592849254608,10,100 +space_ga,2,10,20,8,1e-05,0.06827209889888763,20,100 +space_ga,2,10,20,8,1e-05,0.025553066283464432,1,100 +space_ga,2,10,20,8,1e-05,0.02541356533765793,5,100 +space_ga,2,10,20,8,1e-05,0.2901467978954315,20,100 +space_ga,2,10,20,8,1e-05,0.04546191915869713,20,100 
+space_ga,2,10,20,8,1e-05,0.23150818049907684,0,100 +space_ga,2,10,20,8,1e-05,0.025417599827051163,9,100 +space_ga,2,10,20,8,1e-05,0.33422788977622986,16,100 +space_ga,2,10,20,8,1e-05,0.05376894772052765,0,100 +space_ga,2,10,20,8,1e-05,0.026506789028644562,0,100 +space_ga,2,10,20,8,1e-05,0.044850919395685196,0,100 +space_ga,2,10,20,8,1e-05,0.025532562285661697,0,100 +space_ga,2,10,20,8,1e-05,0.04196251928806305,20,100 diff --git a/baselines/hfedxgboost/results_centralized.csv b/baselines/hfedxgboost/results_centralized.csv new file mode 100644 index 000000000000..bf276c4cd535 --- /dev/null +++ b/baselines/hfedxgboost/results_centralized.csv @@ -0,0 +1,2 @@ +dataset_name,n_estimators_client,xgb_max_depth,subsample,learning_rate,colsample_bylevel,colsample_bynode,colsample_bytree,alpha,gamma,num_parallel_tree,min_child_weight,result_train,result_test +abalone,200,3,0.4,0.05,1,1,1,5,10,1,5,3.9291864339707256,4.404329529975106 diff --git a/datasets/dev/build-flwr-datasets-docs.sh b/datasets/dev/build-flwr-datasets-docs.sh index dc3cd979d5c8..aefa47f147f8 100755 --- a/datasets/dev/build-flwr-datasets-docs.sh +++ b/datasets/dev/build-flwr-datasets-docs.sh @@ -28,3 +28,6 @@ fi # Note if a package cannot be reach via the recursive traversal, even if it has __all__, it won't be documented. echo "Generating the docs based on only the functionality given in the __all__." sphinx-build -M html source build + +# Remove the autogenerated source files after the build +rm source/ref-api/*.rst diff --git a/datasets/doc/source/_templates/autosummary/base.rst b/datasets/doc/source/_templates/autosummary/base.rst new file mode 100644 index 000000000000..5536fa10863f --- /dev/null +++ b/datasets/doc/source/_templates/autosummary/base.rst @@ -0,0 +1,5 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. 
auto{{ objtype }}:: {{ objname }} diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index 32baa6dd1471..53d7046bfaac 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -14,6 +14,7 @@ # ============================================================================== +import datetime import os import sys from sphinx.application import ConfigError @@ -33,7 +34,7 @@ # -- Project information ----------------------------------------------------- project = "Flower Datasets" -copyright = "2023 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags @@ -74,25 +75,27 @@ # The full name is still at the top of the page add_module_names = False + def find_test_modules(package_path): """Go through the python files and exclude every *_test.py file.""" full_path_modules = [] for root, dirs, files in os.walk(package_path): for file in files: - if file.endswith('_test.py'): + if file.endswith("_test.py"): # Construct the module path relative to the package directory full_path = os.path.join(root, file) relative_path = os.path.relpath(full_path, package_path) # Convert file path to dotted module path - module_path = os.path.splitext(relative_path)[0].replace(os.sep, '.') + module_path = os.path.splitext(relative_path)[0].replace(os.sep, ".") full_path_modules.append(module_path) modules = [] for full_path_module in full_path_modules: - parts = full_path_module.split('.') + parts = full_path_module.split(".") for i in range(len(parts)): - modules.append('.'.join(parts[i:])) + modules.append(".".join(parts[i:])) return modules + # Stop from documenting the *_test.py files. # That's the only way to do that in autosummary (make the modules as mock_imports). 
autodoc_mock_imports = find_test_modules(os.path.abspath("../../")) diff --git a/datasets/e2e/pytorch/pyproject.toml b/datasets/e2e/pytorch/pyproject.toml new file mode 100644 index 000000000000..4565cce9f828 --- /dev/null +++ b/datasets/e2e/pytorch/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "fds-e2e-pytorch" +version = "0.1.0" +description = "Flower Datasets with PyTorch" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = "^3.8" +flwr-datasets = { path = "./../../", extras = ["vision"] } +torch = "^1.12.0" +torchvision = "^0.14.1" +parameterized = "==0.9.0" diff --git a/datasets/e2e/pytorch/pytorch_test.py b/datasets/e2e/pytorch/pytorch_test.py new file mode 100644 index 000000000000..5bac8f770f23 --- /dev/null +++ b/datasets/e2e/pytorch/pytorch_test.py @@ -0,0 +1,131 @@ +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets.utils.logging import disable_progress_bar +from parameterized import parameterized_class, parameterized +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, ToTensor, Normalize + +from flwr_datasets import FederatedDataset + + +class SimpleCNN(nn.Module): + def __init__(self): + super(SimpleCNN, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +# Using parameterized testing, two different sets of parameters are specified: +# 1. CIFAR10 dataset with the simple ToTensor transform. +# 2. 
CIFAR10 dataset with a composed transform that first converts an image to a tensor +# and then normalizes it. +@parameterized_class( + [ + {"dataset_name": "cifar10", "test_split": "test", "transforms": ToTensor()}, + {"dataset_name": "cifar10", "test_split": "test", "transforms": Compose( + [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + )}, + ] +) +class FdsToPyTorch(unittest.TestCase): + """Test the conversion from FDS to PyTorch Dataset and Dataloader.""" + + dataset_name = "" + test_split = "" + transforms = None + trainloader = None + expected_img_shape_after_transform = [3, 32, 32] + + @classmethod + def setUpClass(cls): + """Disable progress bar to keep the log clean. + """ + disable_progress_bar() + + def _create_trainloader(self, batch_size: int) -> DataLoader: + """Create a trainloader from the federated dataset.""" + partition_id = 0 + fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 100}) + partition = fds.load_partition(partition_id, "train") + partition_train_test = partition.train_test_split(test_size=0.2) + partition_train_test = partition_train_test.map( + lambda img: {"img": self.transforms(img)}, input_columns="img" + ) + trainloader = DataLoader( + partition_train_test["train"].with_format("torch"), batch_size=batch_size, + shuffle=True + ) + return trainloader + + def test_create_partition_dataloader_with_transforms_shape(self) -> None: + """Test if the DataLoader returns batches with the expected shape.""" + batch_size = 16 + trainloader = self._create_trainloader(batch_size) + batch = next(iter(trainloader)) + images = batch["img"] + self.assertEqual(tuple(images.shape), + (batch_size, *self.expected_img_shape_after_transform)) + + def test_create_partition_dataloader_with_transforms_batch_type(self) -> None: + """Test if the DataLoader returns batches of type dictionary.""" + batch_size = 16 + trainloader = self._create_trainloader(batch_size) + batch = next(iter(trainloader)) + 
self.assertIsInstance(batch, dict) + + def test_create_partition_dataloader_with_transforms_data_type(self) -> None: + """Test to verify if the data in the DataLoader batches are of type Tensor.""" + batch_size = 16 + trainloader = self._create_trainloader(batch_size) + batch = next(iter(trainloader)) + images = batch["img"] + self.assertIsInstance(images, Tensor) + + @parameterized.expand([ + ("not_nan", torch.isnan), + ("not_inf", torch.isinf), + ]) + def test_train_model_loss_value(self, name, condition_func): + """Test if the model trains and if the loss is a correct number.""" + trainloader = self._create_trainloader(16) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Create the model, criterion, and optimizer + net = SimpleCNN().to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # Training loop for one epoch + net.train() + loss = None + for i, data in enumerate(trainloader, 0): + inputs, labels = data['img'].to(device), data['label'].to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + self.assertFalse(condition_func(loss).item()) + + +if __name__ == '__main__': + unittest.main() diff --git a/datasets/e2e/tensorflow/pyproject.toml b/datasets/e2e/tensorflow/pyproject.toml index 9c5c72c46400..64c67c7695a5 100644 --- a/datasets/e2e/tensorflow/pyproject.toml +++ b/datasets/e2e/tensorflow/pyproject.toml @@ -9,7 +9,8 @@ description = "Flower Datasets with TensorFlow" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" +python = ">=3.8,<3.11" flwr-datasets = { path = "./../../", extras = ["vision"] } tensorflow-cpu = "^2.9.1, !=2.11.1" +tensorflow-io-gcs-filesystem = "<0.35.0" parameterized = "==0.9.0" diff --git a/datasets/e2e/tensorflow/tensorflow_test.py b/datasets/e2e/tensorflow/tensorflow_test.py index e041bcb8f8cc..5e21b5d6386f 100644 --- 
a/datasets/e2e/tensorflow/tensorflow_test.py +++ b/datasets/e2e/tensorflow/tensorflow_test.py @@ -38,8 +38,7 @@ class FdsToTensorFlow(unittest.TestCase): @classmethod def setUpClass(cls): - """Disable progress bar to keep the log clean. - """ + """Disable progress bar to keep the log clean.""" disable_progress_bar() def _create_tensorflow_dataset(self, batch_size: int) -> tf.data.Dataset: diff --git a/datasets/flwr_datasets/common/typing.py b/datasets/flwr_datasets/common/typing.py new file mode 100644 index 000000000000..ffaefaeec313 --- /dev/null +++ b/datasets/flwr_datasets/common/typing.py @@ -0,0 +1,26 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower Datasets type definitions.""" + + +from typing import Any, List + +import numpy as np +import numpy.typing as npt + +NDArray = npt.NDArray[Any] +NDArrayInt = npt.NDArray[np.int_] +NDArrayFloat = npt.NDArray[np.float_] +NDArrays = List[NDArray] diff --git a/dev/format.sh b/dev/format.sh index a5bc32915545..6b9cdaf5f44c 100755 --- a/dev/format.sh +++ b/dev/format.sh @@ -18,7 +18,7 @@ python -m docformatter -i -r examples # Notebooks python -m black --ipynb -q doc/source/*.ipynb -KEYS="metadata.celltoolbar metadata.language_info metadata.toc metadata.notify_time metadata.varInspector metadata.accelerator metadata.vscode cell.metadata.id cell.metadata.heading_collapsed cell.metadata.hidden cell.metadata.code_folding cell.metadata.tags cell.metadata.init_cell cell.metadata.vscode" +KEYS="metadata.celltoolbar metadata.language_info metadata.toc metadata.notify_time metadata.varInspector metadata.accelerator metadata.vscode cell.metadata.id cell.metadata.heading_collapsed cell.metadata.hidden cell.metadata.code_folding cell.metadata.tags cell.metadata.init_cell cell.metadata.vscode cell.metadata.pycharm" python -m nbstripout doc/source/*.ipynb --extra-keys "$KEYS" python -m nbstripout examples/*/*.ipynb --extra-keys "$KEYS" diff --git a/dev/test-copyright.sh b/dev/test-copyright.sh new file mode 100755 index 000000000000..5ba8fedfe1ef --- /dev/null +++ b/dev/test-copyright.sh @@ -0,0 +1,15 @@ +#!/bin/bash -e + +cd "$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"/../ + +EXIT_CODE=0 + +while IFS= read -r -d '' file; do + COPYRIGHT=$(grep -h -r "copyright = " "$file") + if [ "$COPYRIGHT" != "copyright = f\"{datetime.date.today().year} Flower Labs GmbH\"" ]; then + echo "::error file=$file::Wrong copyright line. Expected: copyright = f\"{datetime.date.today().year} Flower Labs GmbH\"" + EXIT_CODE=1 + fi +done < <(find . 
-path "*/doc/source/conf.py" -print0) + +exit $EXIT_CODE diff --git a/doc/build-versioned-docs.sh b/doc/build-versioned-docs.sh index 0d7185b5e470..b61472d216e9 100755 --- a/doc/build-versioned-docs.sh +++ b/doc/build-versioned-docs.sh @@ -76,6 +76,9 @@ END # Restore branch as it was to avoid conflicts git restore source/_templates + git restore source/_templates/autosummary || rm -rf source/_templates/autosummary + rm source/ref-api/*.rst + if [ "$current_version" = "v1.5.0" ]; then git restore source/conf.py fi @@ -100,6 +103,7 @@ for current_language in ${languages}; do export current_language sphinx-build -b html source/ build/html/${current_version}/${current_language} -A lang=True -D language=${current_language} done +rm source/ref-api/*.rst # Copy main version to the root of the built docs cp -r build/html/main/en/* build/html/ diff --git a/doc/source/_static/flower-architecture.drawio.png b/doc/source/_static/flower-architecture.drawio.png index 258c6a2727ea..3f82332a0a32 100755 Binary files a/doc/source/_static/flower-architecture.drawio.png and b/doc/source/_static/flower-architecture.drawio.png differ diff --git a/doc/source/_templates/autosummary/base.rst b/doc/source/_templates/autosummary/base.rst new file mode 100644 index 000000000000..5536fa10863f --- /dev/null +++ b/doc/source/_templates/autosummary/base.rst @@ -0,0 +1,5 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. auto{{ objtype }}:: {{ objname }} diff --git a/doc/source/_templates/autosummary/class.rst b/doc/source/_templates/autosummary/class.rst new file mode 100644 index 000000000000..b4b35789bc6f --- /dev/null +++ b/doc/source/_templates/autosummary/class.rst @@ -0,0 +1,33 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :show-inheritance: + :inherited-members: + + {% block methods %} + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. 
autosummary:: + {% for item in methods %} + {% if item != "__init__" %} + ~{{ name }}.{{ item }} + {% endif %} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/doc/source/_templates/autosummary/module.rst b/doc/source/_templates/autosummary/module.rst new file mode 100644 index 000000000000..571db198d27c --- /dev/null +++ b/doc/source/_templates/autosummary/module.rst @@ -0,0 +1,66 @@ +{{ name | escape | underline}} + +.. automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: Module Attributes + + .. autosummary:: + :toctree: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :toctree: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + :template: autosummary/class.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :toctree: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. rubric:: Modules + +.. 
autosummary:: + :toctree: + :template: autosummary/module.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/doc/source/_templates/sidebar/versioning.html b/doc/source/_templates/sidebar/versioning.html index de607b55373a..dde7528d15e4 100644 --- a/doc/source/_templates/sidebar/versioning.html +++ b/doc/source/_templates/sidebar/versioning.html @@ -57,6 +57,26 @@ } + +
@@ -64,7 +84,13 @@
diff --git a/doc/source/conf.py b/doc/source/conf.py index 8077d26aa6ae..503f76cb9eca 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -14,6 +14,7 @@ # ============================================================================== +import datetime import os import sys from git import Repo @@ -81,7 +82,7 @@ # -- Project information ----------------------------------------------------- project = "Flower" -copyright = "2022 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags @@ -95,6 +96,7 @@ extensions = [ "sphinx.ext.napoleon", "sphinx.ext.autodoc", + "sphinx.ext.autosummary", "sphinx.ext.mathjax", "sphinx.ext.viewcode", "sphinx.ext.graphviz", @@ -108,6 +110,42 @@ "nbsphinx", ] +# Generate .rst files +autosummary_generate = True + +# Document ONLY the objects from __all__ (present in __init__ files). +# It will be done recursively starting from flwr.__init__ +# Starting point is controlled in the index.rst file. 
+autosummary_ignore_module_all = False + +# Each class and function docs start with the path to it +# Make the flwr_datasets.federated_dataset.FederatedDataset appear as FederatedDataset +# The full name is still at the top of the page +add_module_names = False + +def find_test_modules(package_path): + """Go through the python files and exclude every *_test.py file.""" + full_path_modules = [] + for root, dirs, files in os.walk(package_path): + for file in files: + if file.endswith('_test.py'): + # Construct the module path relative to the package directory + full_path = os.path.join(root, file) + relative_path = os.path.relpath(full_path, package_path) + # Convert file path to dotted module path + module_path = os.path.splitext(relative_path)[0].replace(os.sep, '.') + full_path_modules.append(module_path) + modules = [] + for full_path_module in full_path_modules: + parts = full_path_module.split('.') + for i in range(len(parts)): + modules.append('.'.join(parts[i:])) + return modules + +# Stop from documenting the *_test.py files. +# That's the only way to do that in autosummary (make the modules as mock_imports). +autodoc_mock_imports = find_test_modules(os.path.abspath("../../src/py/flwr")) + # Add any paths that contain templates here, relative to this directory. 
templates_path = ["_templates"] @@ -173,7 +211,8 @@ "evaluation": "explanation-federated-evaluation.html", "differential-privacy-wrappers": "explanation-differential-privacy.html", # Restructuring: references - "apiref-flwr": "ref-api-flwr.html", + "apiref-flwr": "ref-api/flwr.html", + "ref-api-flwr": "ref-api/flwr.html", "apiref-cli": "ref-api-cli.html", "examples": "ref-example-projects.html", "telemetry": "ref-telemetry.html", diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst new file mode 100644 index 000000000000..d85e48155de0 --- /dev/null +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -0,0 +1,135 @@ +How to build Docker Flower images locally +========================================= + +Flower provides pre-made docker images on `Docker Hub `_ +that include all necessary dependencies for running the server. You can also build your own custom +docker images from scratch with a different version of Python or Ubuntu if that is what you need. +In this guide, we will explain what images exist and how to build them locally. + +Before we can start, we need to meet a few prerequisites in our local development environment. + +#. Clone the flower repository. + + .. code-block:: bash + + $ git clone https://github.com/adap/flower.git && cd flower + +#. Verify the Docker daemon is running. + + Please follow the first section on + `Run Flower using Docker `_ + which covers this step in more detail. + +Currently, Flower provides two images, a base image and a server image. There will also be a client +image soon. The base image, as the name suggests, contains basic dependencies that both the server +and the client need. This includes system dependencies, Python and Python tools. The server image is +based on the base image, but it additionally installs the Flower server using ``pip``. + +The build instructions that assemble the images are located in the respective Dockerfiles. 
You +can find them in the subdirectories of ``src/docker``. + +Both the base and server images are configured via build arguments. Through build arguments, we can make +our build more flexible. For example, in the base image, we can specify the version of Python to +install using the ``PYTHON_VERSION`` build argument. Some of the build arguments have default +values; others must be specified when building the image. All available build arguments for each +image are listed in one of the tables below. + +Building the base image +----------------------- + +.. list-table:: + :widths: 25 45 15 15 + :header-rows: 1 + + * - Build argument + - Description + - Required + - Example + * - ``PYTHON_VERSION`` + - Version of ``python`` to be installed. + - Yes + - ``3.11`` + * - ``PIP_VERSION`` + - Version of ``pip`` to be installed. + - Yes + - ``23.0.1`` + * - ``SETUPTOOLS_VERSION`` + - Version of ``setuptools`` to be installed. + - Yes + - ``69.0.2`` + * - ``UBUNTU_VERSION`` + - Version of the official Ubuntu Docker image. + - Defaults to ``22.04``. + - + +The following example creates a base image with Python 3.11.0, pip 23.0.1 and setuptools 69.0.2: + +.. code-block:: bash + + $ cd src/docker/base/ + $ docker build \ + --build-arg PYTHON_VERSION=3.11.0 \ + --build-arg PIP_VERSION=23.0.1 \ + --build-arg SETUPTOOLS_VERSION=69.0.2 \ + -t flwr_base:0.1.0 . + +The name of the image is ``flwr_base`` and the tag ``0.1.0``. Remember that the build arguments as well +as the name and tag can be adapted to your needs. These values serve as examples only. + +Building the server image +------------------------- + +.. list-table:: + :widths: 25 45 15 15 + :header-rows: 1 + + * - Build argument + - Description + - Required + - Example + * - ``BASE_REPOSITORY`` + - The repository name of the base image. + - Defaults to ``flwr/server``. + - + * - ``BASE_IMAGE_TAG`` + - The image tag of the base image. + - Defaults to ``py3.11-ubuntu22.04``.
+ - + * - ``FLWR_VERSION`` + - Version of Flower to be installed. + - Yes + - ``1.6.0`` + +The following example creates a server image with the official Flower base image py3.11-ubuntu22.04 +and Flower 1.6.0: + +.. code-block:: bash + + $ cd src/docker/server/ + $ docker build \ + --build-arg BASE_IMAGE_TAG=py3.11-ubuntu22.04 \ + --build-arg FLWR_VERSION=1.6.0 \ + -t flwr_server:0.1.0 . + +The name of the image is ``flwr_server`` and the tag ``0.1.0``. Remember that the build arguments as well +as the name and tag can be adapted to your needs. These values serve as examples only. + +If you want to use your own base image instead of the official Flower base image, all you need to do +is set the ``BASE_REPOSITORY`` and ``BASE_IMAGE_TAG`` build arguments. The value of +``BASE_REPOSITORY`` must match the name of your image and the value of ``BASE_IMAGE_TAG`` must match +the tag of your image. + +.. code-block:: bash + + $ cd src/docker/server/ + $ docker build \ + --build-arg BASE_REPOSITORY=flwr_base \ + --build-arg BASE_IMAGE_TAG=0.1.0 \ + --build-arg FLWR_VERSION=1.6.0 \ + -t flwr_server:0.1.0 . + +After creating the image, we can test whether the image is working: + +.. code-block:: bash + + $ docker run --rm flwr_server:0.1.0 --help diff --git a/doc/source/contributor-how-create-new-messages.rst b/doc/source/contributor-how-to-create-new-messages.rst similarity index 100% rename from doc/source/contributor-how-create-new-messages.rst rename to doc/source/contributor-how-to-create-new-messages.rst diff --git a/doc/source/contributor-how-to-release-flower.rst b/doc/source/contributor-how-to-release-flower.rst index 2eef165c0ed0..acfac4197ec1 100644 --- a/doc/source/contributor-how-to-release-flower.rst +++ b/doc/source/contributor-how-to-release-flower.rst @@ -3,23 +3,15 @@ Release Flower This document describes the current release process. It may or may not change in the future.
-Before the release ------------------- - -Update the changelog (``changelog.md``) with all relevant changes that happened after the last release. If the last release was tagged ``v1.2.0``, you can use the following URL to see all commits that got merged into ``main`` since then: - -`GitHub: Compare v1.2.0...main `_ - -Thank the authors who contributed since the last release. This can be done by running the ``./dev/add-shortlog.sh `` convenience script (it can be ran multiple times and will update the names in the list if new contributors were added in the meantime). - During the release ------------------ The version number of a release is stated in ``pyproject.toml``. To release a new version of Flower, the following things need to happen (in that order): -1. Update the ``changelog.md`` section header ``Unreleased`` to contain the version number and date for the release you are building. Create a pull request with the change. -2. Tag the release commit with the version number as soon as the PR is merged: ``git tag v0.12.3``, then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. -3. Check the draft release on GitHub, and if everything is good, publish it. +1. Run ``python3 src/py/flwr_tool/update_changelog.py `` in order to add every new change to the changelog (feel free to make manual changes to the changelog afterwards until it looks good). +2. Once the changelog has been updated with all the changes, run ``./dev/prepare-release-changelog.sh v<version>``, where ``<version>`` is the version stated in ``pyproject.toml`` (notice the ``v`` added before it). This will replace the ``Unreleased`` header of the changelog with the version and current date, and it will add a thank-you message for the contributors. Open a pull request with those changes. +3.
Once the pull request is merged, tag the release commit with the version number: ``git tag v<version>`` (notice the ``v`` added before the version number), then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. +4. Check the draft release on GitHub, and if everything is good, publish it. After the release ----------------- diff --git a/doc/source/contributor-tutorial-contribute-on-github.rst b/doc/source/contributor-tutorial-contribute-on-github.rst index 9aeb8229b412..351d2408d9f3 100644 --- a/doc/source/contributor-tutorial-contribute-on-github.rst +++ b/doc/source/contributor-tutorial-contribute-on-github.rst @@ -46,7 +46,7 @@ Setting up the repository $ git clone - This will create a `flower/` (or the name of your fork if you renamed it) folder in the current working directory. + This will create a ``flower/`` (or the name of your fork if you renamed it) folder in the current working directory. 4. **Add origin** You can then go into the repository folder: @@ -180,9 +180,9 @@ Creating and merging a pull request (PR) .. image:: _static/compare_and_pr.png - Otherwise you can always find this option in the `Branches` page. + Otherwise you can always find this option in the ``Branches`` page. - Once you click the `Compare & pull request` button, you should see something similar to this: + Once you click the ``Compare & pull request`` button, you should see something similar to this: .. image:: _static/creating_pr.png @@ -195,6 +195,10 @@ Creating and merging a pull request (PR) The input box in the middle is there for you to describe what your PR does and to link it to existing issues. We have placed comments (that won't be rendered once the PR is opened) to guide you through the process. + It is important to follow the instructions described in comments.
For instance, in order to not break how our changelog system works, + you should read the information above the ``Changelog entry`` section carefully. + You can also check out some examples and details in the :ref:`changelogentry` appendix. + At the bottom you will find the button to open the PR. This will notify reviewers that a new PR has been opened and that they should look over it to merge or to request changes. @@ -272,8 +276,8 @@ Solution This is a tiny change, but it’ll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here’s what you should do: -- Find the source file in `doc/source` -- Make the change in the `.rst` file (beware, the dashes under the title should be the same length as the title itself) +- Find the source file in ``doc/source`` +- Make the change in the ``.rst`` file (beware, the dashes under the title should be the same length as the title itself) - Build the docs and check the result: ``_ Rename file @@ -284,18 +288,18 @@ If we just change the file, then we break all existing links to it - it is **ver Here’s how to change the file name: -- Change the file name to `save-progress.rst` -- Add a redirect rule to `doc/source/conf.py` +- Change the file name to ``save-progress.rst`` +- Add a redirect rule to ``doc/source/conf.py`` -This will cause a redirect from `saving-progress.html` to `save-progress.html`, old links will continue to work. +This will cause a redirect from ``saving-progress.html`` to ``save-progress.html``, old links will continue to work. Apply changes in the index file ::::::::::::::::::::::::::::::: -For the lateral navigation bar to work properly, it is very important to update the `index.rst` file as well. +For the lateral navigation bar to work properly, it is very important to update the ``index.rst`` file as well. This is where we define the whole arborescence of the navbar.
-- Find and modify the file name in `index.rst` +- Find and modify the file name in ``index.rst`` Open PR ::::::: @@ -331,7 +335,7 @@ Here are a few positive examples which provide helpful information without repea * Update docs banner to mention Flower Summit 2023 * Remove unnecessary XGBoost dependency * Remove redundant attributes in strategies subclassing FedAvg -* Add CI job to deploy the staging system when the `main` branch changes +* Add CI job to deploy the staging system when the ``main`` branch changes * Add new amazing library which will be used to improve the simulation engine @@ -341,3 +345,83 @@ Next steps Once you have made your first PR, and want to contribute more, be sure to check out the following : - `Good first contributions `_, where you should particularly look into the :code:`baselines` contributions. + + +Appendix +-------- + +.. _changelogentry: + +Changelog entry +*************** + +When opening a new PR, inside its description, there should be a ``Changelog entry`` header. + +Above this header you should see the following comment that explains how to write your changelog entry: + + Inside the following 'Changelog entry' section, + you should put the description of your changes that will be added to the changelog alongside your PR title. + + If the section is completely empty (without any token), + the changelog will just contain the title of the PR for the changelog entry, without any description. + + If the 'Changelog entry' section is removed entirely, + it will categorize the PR as "General improvement" and add it to the changelog accordingly. + + If the section contains some text other than tokens, it will use it to add a description to the change. + + If the section contains one of the following tokens it will ignore any other text and put the PR under the corresponding section of the changelog: + + is for classifying a PR as a general improvement. 
+ + is to not add the PR to the changelog + + is to add a general baselines change to the PR + + is to add a general examples change to the PR + + is to add a general sdk change to the PR + + is to add a general simulations change to the PR + + Note that only one token should be used. + +Its content must have a specific format. We will break down what each possibility does: + +- If the ``### Changelog entry`` section is removed, the following text will be added to the changelog:: + + - **General improvements** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains nothing but exists, the following text will be added to the changelog:: + + - **PR TITLE** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains a description (and no token), the following text will be added to the changelog:: + + - **PR TITLE** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + + DESCRIPTION FROM THE CHANGELOG ENTRY + +- If the ``### Changelog entry`` section contains ````, nothing will change in the changelog. 
+ +- If the ``### Changelog entry`` section contains ``<general>``, the following text will be added to the changelog:: + + - **General improvements** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ``<baselines>``, the following text will be added to the changelog:: + + - **General updates to Flower Baselines** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ``<examples>``, the following text will be added to the changelog:: + + - **General updates to Flower Examples** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ``<sdk>``, the following text will be added to the changelog:: + + - **General updates to Flower SDKs** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +- If the ``### Changelog entry`` section contains ``<simulations>``, the following text will be added to the changelog:: + + - **General updates to Flower Simulations** ([#PR_NUMBER](https://github.com/adap/flower/pull/PR_NUMBER)) + +Note that only one token must be provided; otherwise, only the first action (in the order listed above) will be performed. diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index e035d1b9867d..72c6df5fdbc7 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -4,7 +4,7 @@ Get started as a contributor Prerequisites ------------- -- `Python 3.7 `_ or above +- `Python 3.8 `_ or above - `Poetry 1.3 `_ or above - (Optional) `pyenv `_ - (Optional) `pyenv-virtualenv `_ @@ -17,26 +17,49 @@ supports `PEP 517 `_. Developer Machine Setup ----------------------- -First, clone the `Flower repository `_ from +Preliminaries +~~~~~~~~~~~~~ +Some system-wide dependencies are needed. 
+ +For macOS +^^^^^^^^^ + +* Install `homebrew `_. Don't forget the post-installation actions to add `brew` to your PATH. +* Install `xz` (to install different Python versions) and `pandoc` to build the + docs:: + + $ brew install xz pandoc + +For Ubuntu +^^^^^^^^^^ +Ensure your system (Ubuntu 22.04+) is up-to-date, and you have all necessary +packages:: + + $ apt update + $ apt install build-essential zlib1g-dev libssl-dev libsqlite3-dev \ + libreadline-dev libbz2-dev libffi-dev liblzma-dev pandoc + + +Create Flower Dev Environment +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +1. Clone the `Flower repository `_ from GitHub:: $ git clone git@github.com:adap/flower.git $ cd flower -Second, create a virtual environment (and activate it). If you chose to use -:code:`pyenv` (with the :code:`pyenv-virtualenv` plugin) and already have it installed -, you can use the following convenience script (by default it will use :code:`Python 3.8.17`, -but you can change it by providing a specific :code:``):: +2. Let's create the Python environment for all-things Flower. If you wish to use :code:`pyenv`, we provide two convenience scripts that you can use. If you prefer using something other than :code:`pyenv`, create a new environment, activate it, and skip to the last point where all packages are installed. 
- $ ./dev/venv-create.sh +* If you don't have :code:`pyenv` installed, the following script that will install it, set it up, and create the virtual environment (with :code:`Python 3.8.17` by default):: -If you don't have :code:`pyenv` installed, -you can use the following script that will install pyenv, -set it up and create the virtual environment (with :code:`Python 3.8.17` by default):: + $ ./dev/setup-defaults.sh # once completed, run the bootstrap script + +* If you already have :code:`pyenv` installed (along with the :code:`pyenv-virtualenv` plugin), you can use the following convenience script (with :code:`Python 3.8.17` by default):: - $ ./dev/setup-defaults.sh + $ ./dev/venv-create.sh # once completed, run the `bootstrap.sh` script -Third, install the Flower package in development mode (think +3. Install the Flower package in development mode (think :code:`pip install -e`) along with all necessary dependencies:: (flower-) $ ./dev/bootstrap.sh diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index b2efde176fc9..ff3dbb605846 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth Install stable release ---------------------- +Using pip +~~~~~~~~~ + Stable releases are available on `PyPI `_:: python -m pip install flwr @@ -20,10 +23,29 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed python -m pip install flwr[simulation] +Using conda (or mamba) +~~~~~~~~~~~~~~~~~~~~~~ + +Flower can also be installed from the ``conda-forge`` channel. 
+ +If you have not added ``conda-forge`` to your channels, you will first need to run the following:: + + conda config --add channels conda-forge + conda config --set channel_priority strict + +Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``:: + + conda install flwr + +or with ``mamba``:: + + mamba install flwr + + Verify installation ------------------- -The following command can be used to verfiy if Flower was successfully installed. If everything worked, it should print the version of Flower to the command line:: +The following command can be used to verify if Flower was successfully installed. If everything worked, it should print the version of Flower to the command line:: python -c "import flwr;print(flwr.__version__)" 1.5.0 @@ -32,6 +54,11 @@ The following command can be used to verfiy if Flower was successfully installed Advanced installation options ----------------------------- +Install via Docker +~~~~~~~~~~~~~~~~~~ + +`How to run Flower using Docker `_ + Install pre-release ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/how-to-run-flower-using-docker.rst b/doc/source/how-to-run-flower-using-docker.rst new file mode 100644 index 000000000000..27ff61c280cb --- /dev/null +++ b/doc/source/how-to-run-flower-using-docker.rst @@ -0,0 +1,144 @@ +Run Flower using Docker +======================= + +The simplest way to get started with Flower is by using the pre-made Docker images, which you can +find on `Docker Hub `_. + +Before you start, make sure that the Docker daemon is running: + +.. code-block:: bash + + $ docker -v + Docker version 24.0.7, build afdd53b + +If you do not see the version of Docker but instead get an error saying that the command +was not found, you will need to install Docker first. You can find installation instructions +`here `_. + +.. note:: + + On Linux, Docker commands require ``sudo`` privilege. 
If you want to avoid using ``sudo``, + you can follow the `Post-installation steps `_ + on the official Docker website. + +Flower server +------------- + +Quickstart +~~~~~~~~~~ + +If you're looking to try out Flower, you can use the following command: + +.. code-block:: bash + + $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/server:1.6.0-py3.11-ubuntu22.04 \ + --insecure + +The command will pull the Docker image with the tag ``1.6.0-py3.11-ubuntu22.04`` from Docker Hub. +The tag specifies which versions of Flower, Python and Ubuntu are used. In this case, it +uses Flower 1.6.0, Python 3.11 and Ubuntu 22.04. The ``--rm`` flag tells Docker to remove +the container after it exits. + +.. note:: + + By default, the Flower server keeps state in-memory. When using the Docker flag + ``--rm``, the state is not persisted between container starts. We will show below how to save the + state in a file on your host system. + +The ``-p <host>:<container>`` flag tells Docker to map the ports ``9091``/``9092`` of the host to +``9091``/``9092`` of the container, allowing you to access the Driver API on ``http://localhost:9091`` +and the Fleet API on ``http://localhost:9092``. Lastly, any flag that comes after the tag is passed +to the Flower server. Here, we are passing the flag ``--insecure``. + +.. attention:: + + The ``--insecure`` flag enables insecure communication (using HTTP, not HTTPS) and should only be used + for testing purposes. We strongly recommend enabling + `SSL `_ + when deploying to a production environment. + +You can use ``--help`` to view all available flags that the server supports: + +.. 
code-block:: bash + + $ docker run --rm flwr/server:1.6.0-py3.11-ubuntu22.04 --help + +Mounting a volume to store the state on the host system +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to persist the state of the server on your host system, all you need to do is specify a +path where you want to save the file on your host system and a name for the database file. In the +example below, we tell Docker via the flag ``-v`` to mount the user's home directory +(``~/`` on your host) into the ``/app/`` directory of the container. Furthermore, we use the +flag ``--database`` to specify the name of the database file. + +.. code-block:: bash + + $ docker run --rm \ + -p 9091:9091 -p 9092:9092 -v ~/:/app/ flwr/server:1.6.0-py3.11-ubuntu22.04 \ + --insecure \ + --database state.db + +As soon as the server starts, the file ``state.db`` is created in the user's home directory on +your host system. If the file already exists, the server tries to restore the state from the file. +To start the server with an empty database, simply remove the ``state.db`` file. + +Enabling SSL for secure connections +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To enable SSL, you will need a CA certificate, a server certificate and a server private key. + +.. note:: + For testing purposes, you can generate your own self-signed certificates. The + `Enable SSL connections `_ + page contains a section that will guide you through the process. + +Assuming all files we need are in the local ``certificates`` directory, we can use the flag +``-v`` to mount the local directory into the ``/app/`` directory of the container. This allows the +server to access the files within the container. Finally, we pass the names of the certificates to +the server with the ``--certificates`` flag. + +.. 
code-block:: bash + + $ docker run --rm \ + -p 9091:9091 -p 9092:9092 -v ./certificates/:/app/ flwr/server:1.6.0-py3.11-ubuntu22.04 \ + --certificates ca.crt server.pem server.key + +Using a different Flower or Python version +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to use a different version of Flower or Python, you can do so by changing the tag. +All versions we provide are available on `Docker Hub `_. + +Pinning a Docker image to a specific version +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +It may happen that we update the images behind the tags. Such updates usually include security +updates of system dependencies that should not change the functionality of Flower. However, if you +want to ensure that you always use the same image, you can specify the hash of the image instead of +the tag. + +The following command returns the current image hash referenced by the ``server:1.6.0-py3.11-ubuntu22.04`` tag: + +.. code-block:: bash + + $ docker inspect --format='{{index .RepoDigests 0}}' flwr/server:1.6.0-py3.11-ubuntu22.04 + flwr/server@sha256:43fc389bcb016feab2b751b2ccafc9e9a906bb0885bd92b972329801086bc017 + +Next, we can pin the hash when running a new server container: + +.. code-block:: bash + + $ docker run \ + --rm flwr/server@sha256:43fc389bcb016feab2b751b2ccafc9e9a906bb0885bd92b972329801086bc017 \ + --insecure + +Setting environment variables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To set a variable inside a Docker container, you can use the ``-e =`` flag. + +.. 
code-block:: bash + + $ docker run -e FLWR_TELEMETRY_ENABLED=0 \ + --rm flwr/server:1.6.0-py3.11-ubuntu22.04 --insecure diff --git a/doc/source/how-to-save-and-load-model-checkpoints.rst b/doc/source/how-to-save-and-load-model-checkpoints.rst index 404df485fbae..0d711e375cd8 100644 --- a/doc/source/how-to-save-and-load-model-checkpoints.rst +++ b/doc/source/how-to-save-and-load-model-checkpoints.rst @@ -91,3 +91,7 @@ To load your progress, you simply append the following lines to your code. Note print("Loading pre-trained model from: ", latest_round_file) state_dict = torch.load(latest_round_file) net.load_state_dict(state_dict) + state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()] + parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays) + +Return/use this object of type ``Parameters`` wherever necessary, such as in the ``initial_parameters`` when defining a ``Strategy``. \ No newline at end of file diff --git a/doc/source/how-to-use-built-in-middleware-layers.rst b/doc/source/how-to-use-built-in-middleware-layers.rst new file mode 100644 index 000000000000..2e91623b26be --- /dev/null +++ b/doc/source/how-to-use-built-in-middleware-layers.rst @@ -0,0 +1,87 @@ +Use Built-in Middleware Layers +============================== + +**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** + +In this tutorial, we will learn how to utilize built-in middleware layers to augment the behavior of a ``FlowerCallable``. Middleware allows us to perform operations before and after a task is processed in the ``FlowerCallable``. + +What is middleware? +------------------- + +Middleware is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect incoming tasks (``TaskIns``) in the ``Fwd`` and the resulting tasks (``TaskRes``) in the ``Bwd``. The signature for a middleware layer (``Layer``) is as follows: + +.. 
code-block:: python + + FlowerCallable = Callable[[Fwd], Bwd] + Layer = Callable[[Fwd, FlowerCallable], Bwd] + +A typical middleware function might look something like this: + +.. code-block:: python + + def example_middleware(fwd: Fwd, ffn: FlowerCallable) -> Bwd: + # Do something with Fwd before passing to the inner ``FlowerCallable``. + bwd = ffn(fwd) + # Do something with Bwd before returning. + return bwd + +Using middleware layers +----------------------- + +To use middleware layers in your ``FlowerCallable``, you can follow these steps: + +1. Import the required middleware +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, import the built-in middleware layers you intend to use: + +.. code-block:: python + + import flwr as fl + from flwr.client.middleware import example_middleware1, example_middleware2 + +2. Define your client function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Define your client function (``client_fn``) that will be wrapped by the middleware: + +.. code-block:: python + + def client_fn(cid): + # Your client code goes here. + return # your client + +3. Create the ``FlowerCallable`` with middleware +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Create your ``FlowerCallable`` and pass the middleware layers as a list to the ``layers`` argument. The order in which you provide the middleware layers matters: + +.. code-block:: python + + flower = fl.app.Flower( + client_fn=client_fn, + layers=[ + example_middleware1, # Middleware layer 1 + example_middleware2, # Middleware layer 2 + ] + ) + +Order of execution +------------------ + +When the ``FlowerCallable`` runs, the middleware layers are executed in the order they are provided in the list: + +1. ``example_middleware1`` (outermost layer) +2. ``example_middleware2`` (next layer) +3. Message handler (core function that handles ``TaskIns`` and returns ``TaskRes``) +4. ``example_middleware2`` (on the way back) +5. 
``example_middleware1`` (outermost layer on the way back) + +Each middleware has a chance to inspect and modify the ``TaskIns`` in the ``Fwd`` before passing it to the next layer, and likewise with the ``TaskRes`` in the ``Bwd`` before returning it up the stack. + +Conclusion +---------- + +By following this guide, you have learned how to effectively use middleware layers to enhance your ``FlowerCallable``'s functionality. Remember that the order of middleware is crucial and affects how the input and output are processed. + +Enjoy building more robust and flexible ``FlowerCallable`` instances with middleware layers! diff --git a/doc/source/index.rst b/doc/source/index.rst index c4a313414d3a..5df591d6ce05 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -91,6 +91,8 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal. how-to-configure-logging how-to-enable-ssl-connections how-to-upgrade-to-flower-1.0 + how-to-use-built-in-middleware-layers + how-to-run-flower-using-docker .. toctree:: :maxdepth: 1 @@ -119,11 +121,17 @@ References Information-oriented API reference and other reference material. +.. autosummary:: + :toctree: ref-api + :template: autosummary/module.rst + :caption: API reference + :recursive: + + flwr + .. toctree:: :maxdepth: 2 - :caption: API reference - ref-api-flwr ref-api-cli .. toctree:: @@ -160,6 +168,7 @@ The Flower community welcomes contributions. The following docs are intended to contributor-how-to-write-documentation contributor-how-to-release-flower contributor-how-to-contribute-translations + contributor-how-to-build-docker-images .. toctree:: :maxdepth: 1 diff --git a/doc/source/ref-api-flwr.rst b/doc/source/ref-api-flwr.rst deleted file mode 100644 index e1983cd92c90..000000000000 --- a/doc/source/ref-api-flwr.rst +++ /dev/null @@ -1,265 +0,0 @@ -flwr (Python API reference) -=========================== - - -.. _flwr-client-apiref: - -client ------- - -.. automodule:: flwr.client - -.. 
_flwr-client-Client-apiref: - -Client -~~~~~~ - -.. autoclass:: flwr.client.Client - :members: - - -.. _flwr-client-start_client-apiref: - -start_client -~~~~~~~~~~~~ - -.. autofunction:: flwr.client.start_client - - -.. _flwr-client-NumPyClient-apiref: - -NumPyClient -~~~~~~~~~~~ - -.. autoclass:: flwr.client.NumPyClient - :members: - - -.. _flwr-client-start_numpy_client-apiref: - -start_numpy_client -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.client.start_numpy_client - - -.. _flwr-simulation-start_simulation-apiref: - -start_simulation -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.simulation.start_simulation - - -.. _flwr-server-apiref: - -server ------- - -.. automodule:: flwr.server - - -.. _flwr-server-start_server-apiref: - -server.start_server -~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: flwr.server.start_server - - -.. _flwr-server-strategy-apiref: - -server.strategy -~~~~~~~~~~~~~~~ - -.. automodule:: flwr.server.strategy - - -.. _flwr-server-strategy-Strategy-apiref: - -server.strategy.Strategy -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Strategy - :members: - - -.. _flwr-server-strategy-FedAvg-apiref: - -server.strategy.FedAvg -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedAvgM-apiref: - -server.strategy.FedAvgM -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAvgM - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedMedian-apiref: - -server.strategy.FedMedian -^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedMedian - :members: - - .. automethod:: __init__ - -.. _flwr-server-strategy-QFedAvg-apiref: - -server.strategy.QFedAvg -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.QFedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FaultTolerantFedAvg-apiref: - -server.strategy.FaultTolerantFedAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. 
autoclass:: flwr.server.strategy.FaultTolerantFedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedOpt-apiref: - -server.strategy.FedOpt -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedOpt - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedProx-apiref: - -server.strategy.FedProx -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedProx - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedAdagrad-apiref: - -server.strategy.FedAdagrad -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAdagrad - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedAdam-apiref: - -server.strategy.FedAdam -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedAdam - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedYogi-apiref: - -server.strategy.FedYogi -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedYogi - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedTrimmedAvg-apiref: - -server.strategy.FedTrimmedAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedTrimmedAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-Krum-apiref: - -server.strategy.Krum -^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Krum - :members: - - .. automethod:: __init__ - -.. _flwr-server-strategy-Bulyan-apiref: - -server.strategy.Bulyan -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.Bulyan - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-FedXgbNnAvg-apiref: - -server.strategy.FedXgbNnAvg -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.FedXgbNnAvg - :members: - - .. automethod:: __init__ - - -.. _flwr-server-strategy-DPFedAvgAdaptive-apiref: - -server.strategy.DPFedAvgAdaptive -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.DPFedAvgAdaptive - :members: - - .. 
automethod:: __init__ - - -.. _flwr-server-strategy-DPFedAvgFixed-apiref: - -server.strategy.DPFedAvgFixed -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: flwr.server.strategy.DPFedAvgFixed - :members: - - .. automethod:: __init__ - -common ------- - -.. automodule:: flwr.common - :members: - :exclude-members: event diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 8fbfb4843756..5f323bc80baa 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,6 +2,24 @@ ## Unreleased +- **Add scikit-learn tabular data example** ([#2719](https://github.com/adap/flower/pull/2719)) + +- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) + +- **Retiring MXNet examples** The development of the MXNet framework has ended and the project is now [archived on GitHub](https://github.com/apache/mxnet). Existing MXNet examples won't receive updates [#2724](https://github.com/adap/flower/pull/2724) + +- **Update Flower Baselines** + + - HFedXGBoost [#2226](https://github.com/adap/flower/pull/2226) + + - FedVSSL [#2412](https://github.com/adap/flower/pull/2412) + + - FedNova [#2179](https://github.com/adap/flower/pull/2179) + + - HeteroFL [#2439](https://github.com/adap/flower/pull/2439) + + - FedAvgM [#2246](https://github.com/adap/flower/pull/2246) + ## v1.6.0 (2023-11-28) ### Thanks to our contributors diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst index 149d060e4c00..ff8d4b2087dd 100644 --- a/doc/source/tutorial-quickstart-mxnet.rst +++ b/doc/source/tutorial-quickstart-mxnet.rst @@ -4,6 +4,8 @@ Quickstart MXNet ================ +.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongside Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. + .. 
meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST. diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb old mode 100755 new mode 100644 index 41c9254e9d69..bbd916b32375 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -9,7 +9,7 @@ "\n", "Welcome to the Flower federated learning tutorial!\n", "\n", - "In this notebook, we'll build a federated learning system using Flower and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", + "In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.dev/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n", "\n", "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! 
And if anything is unclear, head over to the `#questions` channel.\n", "\n", @@ -31,7 +31,7 @@ "source": [ "### Installing dependencies\n", "\n", - "Next, we install the necessary packages for PyTorch (`torch` and `torchvision`) and Flower (`flwr`):" + "Next, we install the necessary packages for PyTorch (`torch` and `torchvision`), Flower Datasets (`flwr-datasets`) and Flower (`flwr`):" ] }, { @@ -40,7 +40,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q flwr[simulation] torch torchvision matplotlib" + "!pip install -q flwr[simulation] flwr_datasets[vision] torch torchvision matplotlib" ] }, { @@ -64,18 +64,19 @@ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "import torchvision\n", "import torchvision.transforms as transforms\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision.datasets import CIFAR10\n", + "from datasets.utils.logging import disable_progress_bar\n", + "from torch.utils.data import DataLoader\n", "\n", "import flwr as fl\n", "from flwr.common import Metrics\n", + "from flwr_datasets import FederatedDataset\n", "\n", "DEVICE = torch.device(\"cpu\") # Try \"cuda\" to train on GPU\n", "print(\n", " f\"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}\"\n", - ")" + ")\n", + "disable_progress_bar()" ] }, { @@ -92,27 +93,7 @@ "\n", "### Loading the data\n", "\n", - "Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. 
CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CLASSES = (\n", - " \"plane\",\n", - " \"car\",\n", - " \"bird\",\n", - " \"cat\",\n", - " \"deer\",\n", - " \"dog\",\n", - " \"frog\",\n", - " \"horse\",\n", - " \"ship\",\n", - " \"truck\",\n", - ")" + "Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'." ] }, { @@ -121,16 +102,7 @@ "source": [ "We simulate having multiple datasets from multiple organizations (also called the \"cross-silo\" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes, in the real world there's no need for data splitting because each organization already has their own data (so the data is naturally partitioned).\n", "\n", - "Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_CLIENTS = 10" + "Each organization will act as a client in the federated learning system. 
So having ten organizations participate in a federation means having ten clients connected to the federated learning server.\n" ] }, { @@ -138,7 +110,7 @@ "metadata": {}, "source": [ "\n", - "Let's now load the CIFAR-10 training and test set, partition them into ten smaller datasets (each split into training and validation set), and wrap the resulting partitions by creating a PyTorch `DataLoader` for each of them:" + "Let's now create the Federated Dataset abstraction from `flwr-datasets` that partitions the CIFAR-10 dataset. We will create a small training and test set for each edge device and wrap each of them into a PyTorch `DataLoader`:" ] }, { @@ -147,32 +119,36 @@ "metadata": {}, "outputs": [], "source": [ + "NUM_CLIENTS = 10\n", "BATCH_SIZE = 32\n", "\n", "\n", "def load_datasets():\n", - " # Download and transform CIFAR-10 (train and test)\n", - " transform = transforms.Compose(\n", - " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", - " )\n", - " trainset = CIFAR10(\"./dataset\", train=True, download=True, transform=transform)\n", - " testset = CIFAR10(\"./dataset\", train=False, download=True, transform=transform)\n", - "\n", - " # Split training set into 10 partitions to simulate the individual dataset\n", - " partition_size = len(trainset) // NUM_CLIENTS\n", - " lengths = [partition_size] * NUM_CLIENTS\n", - " datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))\n", - "\n", - " # Split each partition into train/val and create DataLoader\n", + " fds = FederatedDataset(dataset=\"cifar10\", partitioners={\"train\": NUM_CLIENTS})\n", + "\n", + " def apply_transforms(batch):\n", + " # Instead of passing transforms to CIFAR10(..., transform=transform)\n", + " # we will use this function to dataset.with_transform(apply_transforms)\n", + " # The transforms object is exactly the same\n", + " transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 
0.5), (0.5, 0.5, 0.5)),\n", + " ]\n", + " )\n", + " batch[\"img\"] = [transform(img) for img in batch[\"img\"]]\n", + " return batch\n", + "\n", + " # Create train/val for each partition and wrap it into DataLoader\n", " trainloaders = []\n", " valloaders = []\n", - " for ds in datasets:\n", - " len_val = len(ds) // 10 # 10 % validation set\n", - " len_train = len(ds) - len_val\n", - " lengths = [len_train, len_val]\n", - " ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))\n", - " trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))\n", - " valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))\n", + " for partition_id in range(NUM_CLIENTS):\n", + " partition = fds.load_partition(partition_id, \"train\")\n", + " partition = partition.with_transform(apply_transforms)\n", + " partition = partition.train_test_split(train_size=0.8)\n", + " trainloaders.append(DataLoader(partition[\"train\"], batch_size=BATCH_SIZE))\n", + " valloaders.append(DataLoader(partition[\"test\"], batch_size=BATCH_SIZE))\n", + " testset = fds.load_full(\"test\").with_transform(apply_transforms)\n", " testloader = DataLoader(testset, batch_size=BATCH_SIZE)\n", " return trainloaders, valloaders, testloader\n", "\n", @@ -195,8 +171,8 @@ "metadata": {}, "outputs": [], "source": [ - "images, labels = next(iter(trainloaders[0]))\n", - "\n", + "batch = next(iter(trainloaders[0]))\n", + "images, labels = batch[\"img\"], batch[\"label\"]\n", "# Reshape and convert images to a NumPy array\n", "# matplotlib requires images with the shape (height, width, 3)\n", "images = images.permute(0, 2, 3, 1).numpy()\n", @@ -209,7 +185,7 @@ "# Loop over the images and plot them\n", "for i, ax in enumerate(axs.flat):\n", " ax.imshow(images[i])\n", - " ax.set_title(CLASSES[labels[i]])\n", + " ax.set_title(trainloaders[0].dataset.features[\"label\"].int2str([labels[i]])[0])\n", " ax.axis(\"off\")\n", "\n", "# Show the plot\n", @@ -294,8 +270,8 @@ " 
net.train()\n", " for epoch in range(epochs):\n", " correct, total, epoch_loss = 0, 0, 0.0\n", - " for images, labels in trainloader:\n", - " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + " for batch in trainloader:\n", + " images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n", " optimizer.zero_grad()\n", " outputs = net(images)\n", " loss = criterion(outputs, labels)\n", @@ -317,8 +293,8 @@ " correct, total, loss = 0, 0, 0.0\n", " net.eval()\n", " with torch.no_grad():\n", - " for images, labels in testloader:\n", - " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + " for batch in testloader:\n", + " images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n", " outputs = net(images)\n", " loss += criterion(outputs, labels).item()\n", " _, predicted = torch.max(outputs.data, 1)\n", @@ -477,7 +453,7 @@ " valloader = valloaders[int(cid)]\n", "\n", " # Create a single Flower client representing a single organization\n", - " return FlowerClient(net, trainloader, valloader)" + " return FlowerClient(net, trainloader, valloader).to_client()" ] }, { @@ -508,10 +484,14 @@ " min_available_clients=10, # Wait until all 10 clients are available\n", ")\n", "\n", - "# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)\n", - "client_resources = None\n", + "# Specify the resources each of your clients need. 
By default, each\n", + "# client will be allocated 1x CPU and 0x GPUs\n", + "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "if DEVICE.type == \"cuda\":\n", - " client_resources = {\"num_gpus\": 1}\n", + " # here we are assigning an entire GPU for each client.\n", + " client_resources = {\"num_cpus\": 1, \"num_gpus\": 1.0}\n", + " # Refer to our documentation for more details about Flower Simulations\n", + " # and how to set up these `client_resources`.\n", "\n", "# Start simulation\n", "fl.simulation.start_simulation(\n", @@ -629,7 +609,7 @@ "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them." + "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" ] } ], @@ -640,7 +620,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "flower-3.7.12", + "display_name": "flwr", "language": "python", "name": "python3" } diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 05b997ff4133..8e5c3adff5e6 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -1,7 +1,10 @@ +from datetime import datetime + import flwr as fl import numpy as np SUBSET_SIZE = 1000 +STATE_VAR = 'timestamp' model_params = np.array([1]) @@ -12,16 +15,29 @@ class FlowerClient(fl.client.NumPyClient): def get_parameters(self, config): return model_params + def _record_timestamp_to_state(self): + """Record timestamp to client's state.""" + t_stamp = datetime.now().timestamp() + if STATE_VAR in self.state.state: + self.state.state[STATE_VAR] += f",{t_stamp}" + else: +
self.state.state[STATE_VAR] = str(t_stamp) + + def _retrieve_timestamp_from_state(self): + return self.state.state[STATE_VAR] + def fit(self, parameters, config): model_params = parameters model_params = [param * (objective/np.mean(param)) for param in model_params] - return model_params, 1, {} + self._record_timestamp_to_state() + return model_params, 1, {STATE_VAR: self._retrieve_timestamp_from_state()} def evaluate(self, parameters, config): model_params = parameters loss = min(np.abs(1 - np.mean(model_params)/objective), 1) accuracy = 1 - loss - return loss, 1, {"accuracy": accuracy} + self._record_timestamp_to_state() + return loss, 1, {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()} def client_fn(cid): return FlowerClient().to_client() diff --git a/e2e/bare/simulation.py b/e2e/bare/simulation.py index 5f0e5334bd08..3a90d90a0ae0 100644 --- a/e2e/bare/simulation.py +++ b/e2e/bare/simulation.py @@ -1,11 +1,41 @@ +from typing import List, Tuple +import numpy as np + import flwr as fl +from flwr.common import Metrics from client import client_fn +STATE_VAR = 'timestamp' + + +# Define metric aggregation function +def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics: + """Ensure that timestamps are monotonically increasing.""" + states = [] + for _, m in metrics: + # split string and covert timestamps to float + states.append([float(tt) for tt in m[STATE_VAR].split(',')]) + + for client_state in states: + if len(client_state) == 1: + continue + deltas = np.diff(client_state) + assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}" + + return {STATE_VAR: states} + + +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics) hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, ) assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / 
hist.losses_distributed[-1][1]) >= 0.98 + +# The checks in record_state_metrics don't do anything if a client's state has a single entry +state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1] +assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds" diff --git a/e2e/jax/pyproject.toml b/e2e/jax/pyproject.toml index fde3c32608ca..3db32ea855eb 100644 --- a/e2e/jax/pyproject.toml +++ b/e2e/jax/pyproject.toml @@ -7,8 +7,8 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = "^3.8" flwr = { path = "../../", develop = true, extras = ["simulation"] } -jax = "^0.4.0" -jaxlib = "^0.4.0" +jax = "==0.4.13" +jaxlib = "==0.4.13" scikit-learn = "^1.1.1" numpy = "^1.21.4" diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index f4e7e0300a06..d180ad5d4eca 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict +from datetime import datetime import torch import torch.nn as nn @@ -18,6 +19,7 @@ warnings.filterwarnings("ignore", category=UserWarning) DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") SUBSET_SIZE = 1000 +STATE_VAR = 'timestamp' class Net(nn.Module): @@ -89,16 +91,29 @@ def load_data(): class FlowerClient(fl.client.NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] + + def _record_timestamp_to_state(self): + """Record timestamp to client's state.""" + t_stamp = datetime.now().timestamp() + if STATE_VAR in self.state.state: + self.state.state[STATE_VAR] += f",{t_stamp}" + else: +
self._record_timestamp_to_state() + return self.get_parameters(config={}), len(trainloader.dataset), {STATE_VAR: self._retrieve_timestamp_from_state()} def evaluate(self, parameters, config): set_parameters(net, parameters) loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} + self._record_timestamp_to_state() + return loss, len(testloader.dataset), {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()} def set_parameters(model, parameters): params_dict = zip(model.state_dict().keys(), parameters) diff --git a/e2e/pytorch/simulation.py b/e2e/pytorch/simulation.py index 5f0e5334bd08..a4c8d4642be2 100644 --- a/e2e/pytorch/simulation.py +++ b/e2e/pytorch/simulation.py @@ -1,11 +1,42 @@ +from typing import List, Tuple +import numpy as np + import flwr as fl +from flwr.common import Metrics + from client import client_fn +STATE_VAR = 'timestamp' + + +# Define metric aggregation function +def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics: + """Ensure that timestamps are monotonically increasing.""" + states = [] + for _, m in metrics: + # split string and covert timestamps to float + states.append([float(tt) for tt in m[STATE_VAR].split(',')]) + + for client_state in states: + if len(client_state) == 1: + continue + deltas = np.diff(client_state) + assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}" + + return {STATE_VAR: states} + + +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics) hist = fl.simulation.start_simulation( client_fn=client_fn, num_clients=2, config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, ) assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 + +# The checks in record_state_metrics don't do anythinng if client's state has a single entry +state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1] 
+assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds" diff --git a/e2e/server.py b/e2e/server.py index 758855caa955..9abfbb27fafc 100644 --- a/e2e/server.py +++ b/e2e/server.py @@ -1,8 +1,47 @@ +from typing import List, Tuple +import numpy as np + + import flwr as fl +from flwr.common import Metrics +STATE_VAR = 'timestamp' + + +# Define metric aggregation function +def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics: + """Ensure that timestamps are monotonically increasing.""" + if not metrics: + return {} + + if STATE_VAR not in metrics[0][1]: + # Do nothing if keyword is not present + return {} + + states = [] + for _, m in metrics: + # split string and covert timestamps to float + states.append([float(tt) for tt in m[STATE_VAR].split(',')]) + + for client_state in states: + if len(client_state) == 1: + continue + deltas = np.diff(client_state) + assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}" + + return {STATE_VAR: states} + + +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics) hist = fl.server.start_server( server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, ) assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 + +if STATE_VAR in hist.metrics_distributed: + # The checks in record_state_metrics don't do anythinng if client's state has a single entry + state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1] + assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds" diff --git a/e2e/strategies/pyproject.toml b/e2e/strategies/pyproject.toml index 611325ce721b..edfb16de59a6 100644 --- a/e2e/strategies/pyproject.toml +++ b/e2e/strategies/pyproject.toml @@ -12,3 
+12,4 @@ authors = ["The Flower Authors "] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorflow-cpu = "^2.9.1, !=2.11.1" +tensorflow-io-gcs-filesystem = "<0.35.0" diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml index 644f422d6762..7b47ffeb1470 100644 --- a/e2e/tabnet/pyproject.toml +++ b/e2e/tabnet/pyproject.toml @@ -14,4 +14,5 @@ flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} tensorflow_datasets = "4.9.2" +tensorflow-io-gcs-filesystem = "<0.35.0" tabnet = "0.1.6" diff --git a/e2e/tensorflow/pyproject.toml b/e2e/tensorflow/pyproject.toml index c66ffc30fdf0..467e69a026b3 100644 --- a/e2e/tensorflow/pyproject.toml +++ b/e2e/tensorflow/pyproject.toml @@ -12,3 +12,4 @@ authors = ["The Flower Authors "] python = ">=3.8,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorflow-cpu = "^2.9.1, !=2.11.1" +tensorflow-io-gcs-filesystem = "<0.35.0" diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index ca54dbf4852f..32314bd22533 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -16,10 +16,10 @@ esac timeout 2m flower-server $server_arg & sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & sleep 3 timeout 2m python driver.py & diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index db0245e41453..2527e8e4a820 100644 --- a/examples/advanced-pytorch/README.md +++ 
b/examples/advanced-pytorch/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (PyTorch) -This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set). @@ -59,12 +59,13 @@ pip install -r requirements.txt The included `run.sh` will start the Flower server (using `server.py`), sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine), -but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows: +but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# After activating your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). 
If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network). +You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index f9ffb6181fd8..b22cbcd70465 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -6,6 +6,7 @@ import argparse from collections import OrderedDict import warnings +import datasets warnings.filterwarnings("ignore") @@ -13,9 +14,9 @@ class CifarClient(fl.client.NumPyClient): def __init__( self, - trainset: torchvision.datasets, - testset: torchvision.datasets, - device: str, + trainset: datasets.Dataset, + testset: datasets.Dataset, + device: torch.device, validation_split: int = 0.1, ): self.device = device @@ -41,17 +42,14 @@ def fit(self, parameters, config): batch_size: int = config["batch_size"] epochs: int = config["local_epochs"] - n_valset = int(len(self.trainset) * self.validation_split) + train_valid = self.trainset.train_test_split(self.validation_split) + trainset = train_valid["train"] + valset = train_valid["test"] - valset = torch.utils.data.Subset(self.trainset, range(0, n_valset)) - trainset = torch.utils.data.Subset( - self.trainset, range(n_valset, len(self.trainset)) - ) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(valset, batch_size=batch_size) - 
trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True) - valLoader = DataLoader(valset, batch_size=batch_size) - - results = utils.train(model, trainLoader, valLoader, epochs, self.device) + results = utils.train(model, train_loader, val_loader, epochs, self.device) parameters_prime = utils.get_model_params(model) num_examples_train = len(trainset) @@ -73,13 +71,13 @@ def evaluate(self, parameters, config): return float(loss), len(self.testset), {"accuracy": float(accuracy)} -def client_dry_run(device: str = "cpu"): +def client_dry_run(device: torch.device = "cpu"): """Weak tests to check whether all client methods are working as expected.""" model = utils.load_efficientnet(classes=10) trainset, testset = utils.load_partition(0) - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) client = CifarClient(trainset, testset, device) client.fit( utils.get_model_params(model), @@ -102,7 +100,7 @@ def main() -> None: help="Do a dry-run to check the client", ) parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -112,9 +110,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to quicky run the client using only 10 datasamples. \ Useful for testing purposes. 
Default: False", ) @@ -136,12 +132,11 @@ def main() -> None: client_dry_run(device) else: # Load a subset of CIFAR-10 to simulate the local data partition - trainset, testset = utils.load_partition(args.partition) + trainset, testset = utils.load_partition(args.client_id) if args.toy: - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) - + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) # Start Flower client client = CifarClient(trainset, testset, device) diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index a12f3c47de70..89fd5a32a89e 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -14,6 +14,7 @@ authors = [ [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" torchvision = "0.14.1" validators = "0.18.2" diff --git a/examples/advanced-pytorch/requirements.txt b/examples/advanced-pytorch/requirements.txt index ba7b284df90e..f4d6a0774162 100644 --- a/examples/advanced-pytorch/requirements.txt +++ b/examples/advanced-pytorch/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 validators==0.18.2 diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh index 212285f504f9..3367e1680535 100755 --- a/examples/advanced-pytorch/run.sh +++ b/examples/advanced-pytorch/run.sh @@ -2,20 +2,17 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)" - # Download the EfficientNetB0 model python -c "import torch; torch.hub.load( \ 'NVIDIA/DeepLearningExamples:torchhub', \ 'nvidia_efficientnet_b0', pretrained=True)" -python server.py & -sleep 3 # Sleep 
for 3s to give the server enough time to start +python server.py --toy & +sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset for i in `seq 0 9`; do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # Enable CTRL+C to stop all background processes diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 8343e62da69f..fda49b71a311 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -10,6 +10,8 @@ import warnings +from flwr_datasets import FederatedDataset + warnings.filterwarnings("ignore") @@ -39,18 +41,13 @@ def evaluate_config(server_round: int): def get_evaluate_fn(model: torch.nn.Module, toy: bool): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - trainset, _, _ = utils.load_data() - - n_train = len(trainset) + # Load data here to avoid the overhead of doing it in `evaluate` itself + centralized_data = utils.load_centralized_data() if toy: # use only 10 samples as validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train)) - else: - # Use the last 5k training examples as a validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train)) + centralized_data = centralized_data.select(range(10)) - valLoader = DataLoader(valset, batch_size=16) + val_loader = DataLoader(centralized_data, batch_size=16) # The `evaluate` function will be called after every round def evaluate( @@ -63,7 +60,7 @@ def evaluate( state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) model.load_state_dict(state_dict, strict=True) - loss, accuracy = utils.test(model, valLoader) + loss, accuracy = utils.test(model, val_loader) return loss, {"accuracy": accuracy} return evaluate @@ -79,9 +76,7 @@ def main(): parser =
argparse.ArgumentParser(description="Flower") parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to use only 10 datasamples for validation. \ Useful for testing purposes. Default: False", ) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 8788ead90dee..6512010b1f23 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,49 +1,45 @@ import torch -import torchvision.transforms as transforms -from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop +from torch.utils.data import DataLoader import warnings -warnings.filterwarnings("ignore") - -# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +from flwr_datasets import FederatedDataset +warnings.filterwarnings("ignore") -def load_data(): - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - trainset = CIFAR10("./dataset", train=True, download=True, transform=transform) - testset = CIFAR10("./dataset", train=False, download=True, transform=transform) +def load_partition(node_id, toy: bool = False): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(node_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + partition_train_test = partition_train_test.with_transform(apply_transforms) + return partition_train_test["train"], partition_train_test["test"] - num_examples = {"trainset": len(trainset), "testset": len(testset)} - return trainset, testset, num_examples +def load_centralized_data(): + fds = 
FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + centralized_data = fds.load_full("test") + centralized_data = centralized_data.with_transform(apply_transforms) + return centralized_data -def load_partition(idx: int): - """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - trainset, testset, num_examples = load_data() - n_train = int(num_examples["trainset"] / 10) - n_test = int(num_examples["testset"] / 10) - train_parition = torch.utils.data.Subset( - trainset, range(idx * n_train, (idx + 1) * n_train) - ) - test_parition = torch.utils.data.Subset( - testset, range(idx * n_test, (idx + 1) * n_test) - ) - return (train_parition, test_parition) +def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + pytorch_transforms = Compose([ + Resize(256), + CenterCrop(224), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch -def train(net, trainloader, valloader, epochs, device: str = "cpu"): +def train(net, trainloader, valloader, epochs, + device: torch.device = torch.device("cpu")): """Train the network on the training set.""" print("Starting training...") net.to(device) # move model to GPU if available @@ -53,7 +49,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): ) net.train() for _ in range(epochs): - for images, labels in trainloader: + for batch in trainloader: + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) optimizer.zero_grad() loss = criterion(net(images), labels) @@ -74,7 +71,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): return results -def test(net, testloader, steps: int = None, device: str = "cpu"): +def test(net, testloader, steps: int = None, + device: torch.device = torch.device("cpu")): """Validate the network on the entire test 
set.""" print("Starting evalutation...") net.to(device) # move model to GPU if available @@ -82,7 +80,8 @@ def test(net, testloader, steps: int = None, device: str = "cpu"): correct, loss = 0, 0.0 net.eval() with torch.no_grad(): - for batch_idx, (images, labels) in enumerate(testloader): + for batch_idx, batch in enumerate(testloader): + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) outputs = net(images) loss += criterion(outputs, labels).item() @@ -109,12 +108,14 @@ def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = entrypoint: EfficientNet model to download. For supported entrypoints, please refer https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/ - classes: Number of classes in final classifying layer. Leave as None to get the downloaded + classes: Number of classes in final classifying layer. Leave as None to get + the downloaded model untouched. Returns: EfficientNet Model - Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch + Note: One alternative implementation can be found at + https://github.com/lukemelas/EfficientNet-PyTorch """ efficientnet = torch.hub.load( "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index 31bf5edb64c6..b21c0d2545ca 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,9 +1,9 @@ # Advanced Flower Example (TensorFlow/Keras) -This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. 
This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) -- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that by default only a small subset of this data is used when running the `run.sh` script) +- Each client holds a local dataset made of 1/10 of the training dataset, split into 80% training examples and 20% test examples (note that by default only a small subset of this data is used when running the `run.sh` script) - Server-side model evaluation after parameter aggregation - Hyperparameter schedule using config functions - Custom return values @@ -57,10 +57,11 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 2 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: +The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# Once you have activated your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py index 1c0b61575635..033f20b1b027 100644 --- a/examples/advanced-tensorflow/client.py +++ b/examples/advanced-tensorflow/client.py @@ -6,6 +6,8 @@ import flwr as fl +from flwr_datasets import FederatedDataset + # Make TensorFlow logs less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -74,7 +76,7 @@ def main() -> None: # Parse command line argument `partition` parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -84,9 +86,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to quicky run the client using only 10 datasamples. " "Useful for testing purposes. 
Default: False", ) @@ -99,7 +99,7 @@ def main() -> None: model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) # Load a subset of CIFAR-10 to simulate the local data partition - (x_train, y_train), (x_test, y_test) = load_partition(args.partition) + x_train, y_train, x_test, y_test = load_partition(args.client_id) if args.toy: x_train, y_train = x_train[:10], y_train[:10] @@ -117,15 +117,16 @@ def main() -> None: def load_partition(idx: int): """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() - return ( - x_train[idx * 5000 : (idx + 1) * 5000], - y_train[idx * 5000 : (idx + 1) * 5000], - ), ( - x_test[idx * 1000 : (idx + 1) * 1000], - y_test[idx * 1000 : (idx + 1) * 1000], - ) + # Download and partition dataset + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(idx) + partition.set_format("numpy") + + # Divide data on each node: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + return x_train, y_train, x_test, y_test if __name__ == "__main__": diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 293ba64b3f43..2f16d8a15584 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and 
platform_machine == \"arm64\""} diff --git a/examples/advanced-tensorflow/requirements.txt b/examples/advanced-tensorflow/requirements.txt index 7a70c46a8128..0cb5fe8c07af 100644 --- a/examples/advanced-tensorflow/requirements.txt +++ b/examples/advanced-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" diff --git a/examples/advanced-tensorflow/run.sh b/examples/advanced-tensorflow/run.sh index 8ddb6a252b52..4acef1371571 100755 --- a/examples/advanced-tensorflow/run.sh +++ b/examples/advanced-tensorflow/run.sh @@ -5,14 +5,11 @@ echo "Starting server" python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start +sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset -# Ensure that the Keras dataset used in client.py is already cached. 
-python -c "import tensorflow as tf; tf.keras.datasets.cifar10.load_data()" - -for i in `seq 0 9`; do +for i in $(seq 0 9); do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/advanced-tensorflow/server.py b/examples/advanced-tensorflow/server.py index e1eb3d4fd8f7..26dde312bee5 100644 --- a/examples/advanced-tensorflow/server.py +++ b/examples/advanced-tensorflow/server.py @@ -4,6 +4,8 @@ import flwr as fl import tensorflow as tf +from flwr_datasets import FederatedDataset + def main() -> None: # Load and compile model for @@ -43,11 +45,11 @@ def main() -> None: def get_evaluate_fn(model): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data() - - # Use the last 5k training examples as a validation set - x_val, y_val = x_train[45000:50000], y_train[45000:50000] + # Load data here to avoid the overhead of doing it in `evaluate` itself + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + test = fds.load_full("test") + test.set_format("numpy") + x_test, y_test = test["img"] / 255.0, test["label"] # The `evaluate` function will be called after every round def evaluate( @@ -56,7 +58,7 @@ def evaluate( config: Dict[str, fl.common.Scalar], ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: model.set_weights(parameters) # Update model with the latest parameters - loss, accuracy = model.evaluate(x_val, y_val) + loss, accuracy = model.evaluate(x_test, y_test) return loss, {"accuracy": accuracy} return evaluate diff --git a/examples/android/README.md b/examples/android/README.md index 7931aa96b0c5..f9f2bb93b8dc 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -54,4 +54,4 @@ poetry run ./run.sh Download 
and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`. -When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training. +When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively. diff --git a/examples/doc/source/conf.py b/examples/doc/source/conf.py index 01cbb48c1587..3d629c39c7ea 100644 --- a/examples/doc/source/conf.py +++ b/examples/doc/source/conf.py @@ -22,8 +22,11 @@ # -- Project information ----------------------------------------------------- +import datetime + + project = "Flower" -copyright = "2022 Flower Labs GmbH" +copyright = f"{datetime.date.today().year} Flower Labs GmbH" author = "The Flower Authors" # The full version, including alpha/beta/rc tags diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py index f5c76ab6dc99..f8124b9353f7 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py @@ -24,7 +24,7 @@ def main(cfg: DictConfig): save_path = HydraConfig.get().runtime.output_dir ## 2. 
Prepare your dataset - # When simulating FL workloads we have a lot of freedom on how the FL clients behave, + # When simulating FL runs we have a lot of freedom on how the FL clients behave, # what data they have, how much data, etc. This is not possible in real FL settings. # In simulation you'd often encounter two types of dataset: # * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST, @@ -91,7 +91,7 @@ def main(cfg: DictConfig): "num_gpus": 0.0, }, # (optional) controls the degree of parallelism of your simulation. # Lower resources per client allow for more clients to run concurrently - # (but need to be set taking into account the compute/memory footprint of your workload) + # (but need to be set taking into account the compute/memory footprint of your run) # `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated # `num_gpus` is a ratio indicating the portion of gpu memory that a client needs. ) diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md index 65ef000c26f2..120e28098344 100644 --- a/examples/mt-pytorch-callable/README.md +++ b/examples/mt-pytorch-callable/README.md @@ -33,13 +33,13 @@ flower-server --insecure In a new terminal window, start the first long-running Flower client: ```bash -flower-client --callable client:flower +flower-client --insecure client:flower ``` In yet another new terminal window, start the second long-running Flower client: ```bash -flower-client --callable client:flower +flower-client --insecure client:flower ``` ## Start the Driver script diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py index 6f9747784ae0..4195a714ca89 100644 --- a/examples/mt-pytorch-callable/client.py +++ b/examples/mt-pytorch-callable/client.py @@ -108,7 +108,7 @@ def client_fn(cid: str): return FlowerClient().to_client() -# To run this: `flower-client --callable client:flower` +# To run this: 
`flower-client client:flower` flower = fl.flower.Flower( client_fn=client_fn, ) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index fed760f021af..ad4d5e1caabe 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -54,13 +54,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() +create_run_res: driver_pb2.CreateRunResponse = driver.create_run( + req=driver_pb2.CreateRunRequest() ) # -------------------------------------------------------------------------- Driver SDK -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") +run_id = create_run_res.run_id +print(f"Created run id {run_id}") history = History() for server_round in range(num_rounds): @@ -93,7 +93,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. 
while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) + get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( @@ -125,7 +125,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: new_task_ins = task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id=workload_id, + run_id=run_id, task=task_pb2.Task( producer=node_pb2.Node( node_id=0, diff --git a/examples/mxnet-from-centralized-to-federated/README.md b/examples/mxnet-from-centralized-to-federated/README.md index 839d3b16a1cf..2c3f240d8978 100644 --- a/examples/mxnet-from-centralized-to-federated/README.md +++ b/examples/mxnet-from-centralized-to-federated/README.md @@ -1,5 +1,7 @@ # MXNet: From Centralized To Federated +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommended. + This example demonstrates how an already existing centralized MXNet-based machine learning project can be federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet project. 
diff --git a/examples/mxnet-from-centralized-to-federated/pyproject.toml b/examples/mxnet-from-centralized-to-federated/pyproject.toml index a0d31f76ebdd..952683eb90f6 100644 --- a/examples/mxnet-from-centralized-to-federated/pyproject.toml +++ b/examples/mxnet-from-centralized-to-federated/pyproject.toml @@ -10,7 +10,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/mxnet-from-centralized-to-federated/requirements.txt b/examples/mxnet-from-centralized-to-federated/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/mxnet-from-centralized-to-federated/requirements.txt +++ b/examples/mxnet-from-centralized-to-federated/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 +flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index 40f7f40e5adc..fccb14158ecd 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -2,7 +2,7 @@ This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. -This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. +This introductory example for Flower uses PyTorch, but you're not required to be a PyTorch expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on existing machine learning projects. 
This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup diff --git a/examples/pytorch-from-centralized-to-federated/__init__.py b/examples/pytorch-from-centralized-to-federated/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index 3c1d67d2f445..e8f3ec3fd724 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -6,22 +6,20 @@ https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html """ - # mypy: ignore-errors # pylint: disable=W0223 -from typing import Tuple, Dict +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms from torch import Tensor -from torchvision.datasets import CIFAR10 +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, ToTensor, Normalize -DATA_ROOT = "./dataset" +from flwr_datasets import FederatedDataset # pylint: disable=unsubscriptable-object @@ -53,19 +51,25 @@ def forward(self, x: Tensor) -> Tensor: return x -def load_data() -> ( - Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict] -): - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] +def load_data(node_id: int): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(node_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + pytorch_transforms = Compose( + [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) - trainset = CIFAR10(DATA_ROOT, 
train=True, download=True, transform=transform) - trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True) - testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform) - testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False) - num_examples = {"trainset": len(trainset), "testset": len(testset)} - return trainloader, testloader, num_examples + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch + + partition_train_test = partition_train_test.with_transform(apply_transforms) + trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) + testloader = DataLoader(partition_train_test["test"], batch_size=32) + return trainloader, testloader def train( @@ -87,7 +91,7 @@ def train( for epoch in range(epochs): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["img"].to(device), data["label"].to(device) # zero the parameter gradients optimizer.zero_grad() @@ -120,7 +124,7 @@ def test( net.eval() with torch.no_grad(): for data in testloader: - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["img"].to(device), data["label"].to(device) outputs = net(images) loss += criterion(outputs, labels).item() _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member @@ -133,7 +137,7 @@ def main(): DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("Centralized PyTorch training") print("Load data") - trainloader, testloader, _ = load_data() + trainloader, testloader = load_data(0) net = Net().to(DEVICE) net.eval() print("Start training") diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index 
88678e0569b7..61c7e7f762b3 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -1,24 +1,22 @@ """Flower client example using PyTorch for CIFAR-10 image classification.""" - - -import os -import sys -import timeit +import argparse from collections import OrderedDict from typing import Dict, List, Tuple -import flwr as fl import numpy as np import torch -import torchvision +from datasets.utils.logging import disable_progress_bar +from torch.utils.data import DataLoader import cifar +import flwr as fl + +disable_progress_bar() + USE_FEDBN: bool = True -# pylint: disable=no-member -DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# pylint: enable=no-member +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Flower Client @@ -28,19 +26,18 @@ class CifarClient(fl.client.NumPyClient): def __init__( self, model: cifar.Net, - trainloader: torch.utils.data.DataLoader, - testloader: torch.utils.data.DataLoader, - num_examples: Dict, + trainloader: DataLoader, + testloader: DataLoader, ) -> None: self.model = model self.trainloader = trainloader self.testloader = testloader - self.num_examples = num_examples def get_parameters(self, config: Dict[str, str]) -> List[np.ndarray]: self.model.train() if USE_FEDBN: - # Return model parameters as a list of NumPy ndarrays, excluding parameters of BN layers when using FedBN + # Return model parameters as a list of NumPy ndarrays, excluding + # parameters of BN layers when using FedBN return [ val.cpu().numpy() for name, val in self.model.state_dict().items() @@ -69,7 +66,7 @@ def fit( # Set model parameters, train model, return updated model parameters self.set_parameters(parameters) cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE) - return self.get_parameters(config={}), self.num_examples["trainset"], {} + return self.get_parameters(config={}), len(self.trainloader.dataset), {} def 
evaluate( self, parameters: List[np.ndarray], config: Dict[str, str] @@ -77,23 +74,26 @@ def evaluate( # Set model parameters, evaluate model on local test dataset, return result self.set_parameters(parameters) loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE) - return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)} + return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)} def main() -> None: """Load data, start CifarClient.""" + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument("--node-id", type=int, required=True, choices=range(0, 10)) + args = parser.parse_args() # Load data - trainloader, testloader, num_examples = cifar.load_data() + trainloader, testloader = cifar.load_data(args.node_id) # Load model model = cifar.Net().to(DEVICE).train() # Perform a single forward pass to properly initialize BatchNorm - _ = model(next(iter(trainloader))[0].to(DEVICE)) + _ = model(next(iter(trainloader))["img"].to(DEVICE)) # Start client - client = CifarClient(model, trainloader, testloader, num_examples) + client = CifarClient(model, trainloader, testloader) fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) diff --git a/examples/pytorch-from-centralized-to-federated/pyproject.toml b/examples/pytorch-from-centralized-to-federated/pyproject.toml index 73999a9e6cd4..6d6f138a0aea 100644 --- a/examples/pytorch-from-centralized-to-federated/pyproject.toml +++ b/examples/pytorch-from-centralized-to-federated/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" torchvision = "0.14.1" diff --git a/examples/pytorch-from-centralized-to-federated/requirements.txt b/examples/pytorch-from-centralized-to-federated/requirements.txt index f3caddbc875e..ba4afad9c288 100644 --- 
a/examples/pytorch-from-centralized-to-federated/requirements.txt +++ b/examples/pytorch-from-centralized-to-federated/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 diff --git a/examples/pytorch-from-centralized-to-federated/run.sh b/examples/pytorch-from-centralized-to-federated/run.sh index c64f362086aa..1ed51dd787ac 100755 --- a/examples/pytorch-from-centralized-to-federated/run.sh +++ b/examples/pytorch-from-centralized-to-federated/run.sh @@ -4,9 +4,9 @@ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --node-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/pytorch-from-centralized-to-federated/server.py b/examples/pytorch-from-centralized-to-federated/server.py index 29cbce1884d1..42f34b3a78e9 100644 --- a/examples/pytorch-from-centralized-to-federated/server.py +++ b/examples/pytorch-from-centralized-to-federated/server.py @@ -1,10 +1,28 @@ """Flower server example.""" +from typing import List, Tuple + import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) -if __name__ == "__main__": - fl.server.start_server( - server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=3), - ) +# Start Flower server +fl.server.start_server( 
+ server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=10), + strategy=strategy, +) diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index c1e3cc4edc06..fd868aa1fcce 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,6 +1,6 @@ # Federated HuggingFace Transformers using Flower and PyTorch -This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for detailed explaination for the transformer pipeline. +This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy. @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py +python3 client.py --node-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py +python3 client.py --node-id 1 ``` You will see that PyTorch is starting a federated training. 
diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 8717d710ad9c..5fa10b9ca0f2 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -1,58 +1,48 @@ -from collections import OrderedDict +import argparse import warnings +from collections import OrderedDict import flwr as fl import torch -import numpy as np - -import random -from torch.utils.data import DataLoader - -from datasets import load_dataset from evaluate import load as load_metric - -from transformers import AutoTokenizer, DataCollatorWithPadding +from torch.optim import AdamW +from torch.utils.data import DataLoader from transformers import AutoModelForSequenceClassification -from transformers import AdamW +from transformers import AutoTokenizer, DataCollatorWithPadding + +from flwr_datasets import FederatedDataset warnings.filterwarnings("ignore", category=UserWarning) DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(): +def load_data(node_id): """Load IMDB data (training and eval)""" - raw_datasets = load_dataset("imdb") - raw_datasets = raw_datasets.shuffle(seed=42) - - # remove unnecessary data split - del raw_datasets["unsupervised"] + fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) + partition = fds.load_partition(node_id) + # Divide data: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) def tokenize_function(examples): return tokenizer(examples["text"], truncation=True) - # random 100 samples - population = random.sample(range(len(raw_datasets["train"])), 100) - - tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) - tokenized_datasets["train"] = tokenized_datasets["train"].select(population) - tokenized_datasets["test"] = tokenized_datasets["test"].select(population) - - tokenized_datasets = 
tokenized_datasets.remove_columns("text") - tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + partition_train_test = partition_train_test.map(tokenize_function, batched=True) + partition_train_test = partition_train_test.remove_columns("text") + partition_train_test = partition_train_test.rename_column("label", "labels") data_collator = DataCollatorWithPadding(tokenizer=tokenizer) trainloader = DataLoader( - tokenized_datasets["train"], + partition_train_test["train"], shuffle=True, batch_size=32, collate_fn=data_collator, ) testloader = DataLoader( - tokenized_datasets["test"], batch_size=32, collate_fn=data_collator + partition_train_test["test"], batch_size=32, collate_fn=data_collator ) return trainloader, testloader @@ -88,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(): +def main(node_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data() + trainloader, testloader = load_data(node_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -122,4 +112,14 @@ def evaluate(self, parameters, config): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + choices=list(range(1_000)), + required=True, + type=int, + help="Partition of the dataset divided into 1,000 iid partitions created " + "artificially.", + ) + node_id = parser.parse_args().node_id + main(node_id) diff --git a/examples/quickstart-huggingface/pyproject.toml b/examples/quickstart-huggingface/pyproject.toml index eb9687c5152c..50ba0b37f8d2 100644 --- a/examples/quickstart-huggingface/pyproject.toml +++ b/examples/quickstart-huggingface/pyproject.toml @@ -14,6 +14,7 @@ authors = [ [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = ">=0.0.2,<1.0.0" torch = ">=1.13.1,<2.0" transformers = ">=4.30.0,<5.0" evaluate = ">=0.4.0,<1.0" diff --git 
a/examples/quickstart-huggingface/requirements.txt b/examples/quickstart-huggingface/requirements.txt index aeb2d13fc4a4..3cd5735625ba 100644 --- a/examples/quickstart-huggingface/requirements.txt +++ b/examples/quickstart-huggingface/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets>=0.0.2, <1.0.0 torch>=1.13.1, <2.0 transformers>=4.30.0, <5.0 evaluate>=0.4.0, <1.0 diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index c64f362086aa..e722a24a21a9 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --node-id ${i}& done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md new file mode 100644 index 000000000000..d94a87a014f7 --- /dev/null +++ b/examples/quickstart-mlx/README.md @@ -0,0 +1,351 @@ +# Flower Example using MLX + +This introductory example to Flower uses [MLX](https://ml-explore.github.io/mlx/build/html/index.html), but deep knowledge of MLX is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. + +[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon. + +In this example, we will train a simple 2 layers MLP on MNIST data (handwritten digits recognition). + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/quickstart-mlx . 
&& rm -rf _tmp && cd quickstart-mlx +``` + +This will create a new directory called `quickstart-mlx` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- run.sh +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `mlx` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with MLX and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the +following commands. 
+ +Start a first client in the first terminal: + +```shell +python3 client.py --node-id 0 +``` + +And another one in the second terminal: + +```shell +python3 client.py --node-id 1 +``` + +If you want to utilize your GPU, you can use the `--gpu` argument: + +```shell +python3 client.py --gpu --node-id 2 +``` + +Note that you can start many more clients if you want, but each will have to be in its own terminal. + +You will see that MLX is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-mlx) for a detailed explanation. + +## Explanations + +This example is a federated version of the centralized case that can be found +[here](https://github.com/ml-explore/mlx-examples/tree/main/mnist). + +### The data + +We will use `flwr_datasets` to easily download and partition the `MNIST` dataset: + +```python +fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) +partition = fds.load_partition(node_id = args.node_id) +partition_splits = partition.train_test_split(test_size=0.2) + +partition_splits['train'].set_format("numpy") +partition_splits['test'].set_format("numpy") + +train_partition = partition_splits["train"].map( + lambda img: { + "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0 + }, + input_columns="image", +) +test_partition = partition_splits["test"].map( + lambda img: { + "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0 + }, + input_columns="image", +) + +data = ( + train_partition["img"], + train_partition["label"].astype(np.uint32), + test_partition["img"], + test_partition["label"].astype(np.uint32), +) + +train_images, train_labels, test_images, test_labels = map(mlx.core.array, data) +``` + +### The model + +We define the model as in the centralized mlx example, it's a simple MLP: + +```python +class MLP(mlx.nn.Module): + """A simple MLP.""" + + def __init__( + self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int + ): + 
super().__init__() + layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] + self.layers = [ + mlx.nn.Linear(idim, odim) + for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + def __call__(self, x): + for l in self.layers[:-1]: + x = mlx.core.maximum(l(x), 0.0) + return self.layers[-1](x) + +``` + +We also define some utility functions to test our model and to iterate over batches. + +```python +def loss_fn(model, X, y): + return mlx.core.mean(mlx.nn.losses.cross_entropy(model(X), y)) + + +def eval_fn(model, X, y): + return mlx.core.mean(mlx.core.argmax(model(X), axis=1) == y) + + +def batch_iterate(batch_size, X, y): + perm = mlx.core.array(np.random.permutation(y.size)) + for s in range(0, y.size, batch_size): + ids = perm[s : s + batch_size] + yield X[ids], y[ids] + +``` + +### The client + +The main changes we have to make to use `MLX` with `Flower` will be found in +the `get_parameters` and `set_parameters` functions. Indeed, MLX doesn't +provide an easy way to convert the model parameters into a list of `np.array`s +(the format we need for the serialization of the messages to work). + +The way MLX stores its parameters is as follows: + +``` +{ + "layers": [ + {"weight": mlx.core.array, "bias": mlx.core.array}, + {"weight": mlx.core.array, "bias": mlx.core.array}, + ..., + {"weight": mlx.core.array, "bias": mlx.core.array} + ] +} +``` + +Therefore, to get our list of `np.array`s, we need to extract each array and +convert them into a numpy array: + +```python +def get_parameters(self, config): + layers = self.model.parameters()["layers"] + return [np.array(val) for layer in layers for _, val in layer.items()] +``` + +For the `set_parameters` function, we perform the reverse operation. We receive +a list of arrays and want to convert them into MLX parameters. 
Therefore, we +iterate through pairs of parameters and assign them to the `weight` and `bias` +keys of each layer dict: + +```python +def set_parameters(self, parameters): + new_params = {} + new_params["layers"] = [ + {"weight": mlx.core.array(parameters[i]), "bias": mlx.core.array(parameters[i + 1])} + for i in range(0, len(parameters), 2) + ] + self.model.update(new_params) +``` + +The rest of the functions are directly inspired by the centralized case: + +```python +def fit(self, parameters, config): + self.set_parameters(parameters) + for _ in range(self.num_epochs): + for X, y in batch_iterate( + self.batch_size, self.train_images, self.train_labels + ): + loss, grads = self.loss_and_grad_fn(self.model, X, y) + self.optimizer.update(self.model, grads) + mlx.core.eval(self.model.parameters(), self.optimizer.state) + return self.get_parameters(config={}), len(self.train_images), {} +``` + +Here, after updating the parameters, we perform the training as in the +centralized case, and return the new parameters. + +And for the `evaluate` function: + +```python +def evaluate(self, parameters, config): + self.set_parameters(parameters) + accuracy = eval_fn(self.model, self.test_images, self.test_labels) + loss = loss_fn(self.model, self.test_images, self.test_labels) + return loss.item(), len(self.test_images), {"accuracy": accuracy.item()} +``` + +We also begin by updating the parameters with the ones sent by the server, and +then we compute the loss and accuracy using the functions defined above. 
+ +Putting everything together we have: + +```python +class FlowerClient(fl.client.NumPyClient): + def __init__( + self, model, optim, loss_and_grad_fn, data, num_epochs, batch_size + ) -> None: + self.model = model + self.optimizer = optim + self.loss_and_grad_fn = loss_and_grad_fn + self.train_images, self.train_labels, self.test_images, self.test_labels = data + self.num_epochs = num_epochs + self.batch_size = batch_size + + def get_parameters(self, config): + layers = self.model.parameters()["layers"] + return [np.array(val) for layer in layers for _, val in layer.items()] + + def set_parameters(self, parameters): + new_params = {} + new_params["layers"] = [ + {"weight": mlx.core.array(parameters[i]), "bias": mlx.core.array(parameters[i + 1])} + for i in range(0, len(parameters), 2) + ] + self.model.update(new_params) + + def fit(self, parameters, config): + self.set_parameters(parameters) + for _ in range(self.num_epochs): + for X, y in batch_iterate( + self.batch_size, self.train_images, self.train_labels + ): + loss, grads = self.loss_and_grad_fn(self.model, X, y) + self.optimizer.update(self.model, grads) + mlx.core.eval(self.model.parameters(), self.optimizer.state) + return self.get_parameters(config={}), len(self.train_images), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + accuracy = eval_fn(self.model, self.test_images, self.test_labels) + loss = loss_fn(self.model, self.test_images, self.test_labels) + return loss.item(), len(self.test_images), {"accuracy": accuracy.item()} +``` + +And as you can see, with only a few lines of code, our client is ready! 
Before +we can instantiate it, we need to define a few variables: + +```python +num_layers = 2 +hidden_dim = 32 +num_classes = 10 +batch_size = 256 +num_epochs = 1 +learning_rate = 1e-1 + +model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) + +loss_and_grad_fn = mlx.nn.value_and_grad(model, loss_fn) +optimizer = mlx.optimizers.SGD(learning_rate=learning_rate) +``` + +Finally, we can instantiate it by using the `start_client` function: + +```python +# Start Flower client +fl.client.start_client( + server_address="127.0.0.1:8080", + client=FlowerClient( + model, + optimizer, + loss_and_grad_fn, + (train_images, train_labels, test_images, test_labels), + num_epochs, + batch_size, + ).to_client(), +) +``` + +### The server + +On the server side, we don't need to add anything in particular. The +`weighted_average` function is just there to be able to aggregate the results +and have an accuracy at the end. + +```python +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) +``` diff --git a/examples/quickstart-mlx/client.py b/examples/quickstart-mlx/client.py new file mode 100644 index 000000000000..3b506399a5f1 --- /dev/null +++ b/examples/quickstart-mlx/client.py @@ -0,0 +1,152 @@ +import argparse + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from flwr_datasets import 
FederatedDataset + +import flwr as fl + + +class MLP(nn.Module): + """A simple MLP.""" + + def __init__( + self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int + ): + super().__init__() + layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] + self.layers = [ + nn.Linear(idim, odim) + for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + def __call__(self, x): + for l in self.layers[:-1]: + x = mx.maximum(l(x), 0.0) + return self.layers[-1](x) + + +def loss_fn(model, X, y): + return mx.mean(nn.losses.cross_entropy(model(X), y)) + + +def eval_fn(model, X, y): + return mx.mean(mx.argmax(model(X), axis=1) == y) + + +def batch_iterate(batch_size, X, y): + perm = mx.array(np.random.permutation(y.size)) + for s in range(0, y.size, batch_size): + ids = perm[s : s + batch_size] + yield X[ids], y[ids] + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def __init__( + self, model, optim, loss_and_grad_fn, data, num_epochs, batch_size + ) -> None: + self.model = model + self.optimizer = optim + self.loss_and_grad_fn = loss_and_grad_fn + self.train_images, self.train_labels, self.test_images, self.test_labels = data + self.num_epochs = num_epochs + self.batch_size = batch_size + + def get_parameters(self, config): + layers = self.model.parameters()["layers"] + return [np.array(val) for layer in layers for _, val in layer.items()] + + def set_parameters(self, parameters): + new_params = {} + new_params["layers"] = [ + {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])} + for i in range(0, len(parameters), 2) + ] + self.model.update(new_params) + + def fit(self, parameters, config): + self.set_parameters(parameters) + for _ in range(self.num_epochs): + for X, y in batch_iterate( + self.batch_size, self.train_images, self.train_labels + ): + loss, grads = self.loss_and_grad_fn(self.model, X, y) + self.optimizer.update(self.model, grads) + mx.eval(self.model.parameters(), self.optimizer.state) 
+ return self.get_parameters(config={}), len(self.train_images), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + accuracy = eval_fn(self.model, self.test_images, self.test_labels) + loss = loss_fn(self.model, self.test_images, self.test_labels) + return loss.item(), len(self.test_images), {"accuracy": accuracy.item()} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") + parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + parser.add_argument( + "--node-id", + choices=[0, 1, 2], + type=int, + help="Partition of the dataset divided into 3 iid partitions created artificially.", + ) + args = parser.parse_args() + if not args.gpu: + mx.set_default_device(mx.cpu) + + num_layers = 2 + hidden_dim = 32 + num_classes = 10 + batch_size = 256 + num_epochs = 1 + learning_rate = 1e-1 + + fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) + partition = fds.load_partition(node_id=args.node_id) + partition_splits = partition.train_test_split(test_size=0.2) + + partition_splits["train"].set_format("numpy") + partition_splits["test"].set_format("numpy") + + train_partition = partition_splits["train"].map( + lambda img: { + "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0 + }, + input_columns="image", + ) + test_partition = partition_splits["test"].map( + lambda img: { + "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0 + }, + input_columns="image", + ) + + data = ( + train_partition["img"], + train_partition["label"].astype(np.uint32), + test_partition["img"], + test_partition["label"].astype(np.uint32), + ) + + train_images, train_labels, test_images, test_labels = map(mx.array, data) + model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + optimizer = optim.SGD(learning_rate=learning_rate) + + # Start Flower client + 
fl.client.start_client( + server_address="127.0.0.1:8080", + client=FlowerClient( + model, + optimizer, + loss_and_grad_fn, + (train_images, train_labels, test_images, test_labels), + num_epochs, + batch_size, + ).to_client(), + ) diff --git a/examples/quickstart-mlx/pyproject.toml b/examples/quickstart-mlx/pyproject.toml new file mode 100644 index 000000000000..deb541c5ba9c --- /dev/null +++ b/examples/quickstart-mlx/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "quickstart-mlx" +version = "0.1.0" +description = "MLX Federated Learning Quickstart with Flower" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +mlx = "==0.0.3" +numpy = "==1.24.4" +flwr-datasets = {extras = ["vision"], version = "^0.0.2"} diff --git a/examples/quickstart-mlx/requirements.txt b/examples/quickstart-mlx/requirements.txt new file mode 100644 index 000000000000..0c3ea45ee188 --- /dev/null +++ b/examples/quickstart-mlx/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0, <2.0 +mlx==0.0.3 +numpy==1.24.4 +flwr-datasets["vision"]>=0.0.2, <1.0 diff --git a/examples/quickstart-mlx/run.sh b/examples/quickstart-mlx/run.sh new file mode 100755 index 000000000000..70281049517d --- /dev/null +++ b/examples/quickstart-mlx/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in $(seq 0 1); do + echo "Starting client $i" + python client.py --node-id $i & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/quickstart-mlx/server.py b/examples/quickstart-mlx/server.py new file mode 100644 index 000000000000..fe691a88aba0 --- /dev/null 
+++ b/examples/quickstart-mlx/server.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/quickstart-mxnet/README.md b/examples/quickstart-mxnet/README.md index 930cec5acdfd..37e01ef2707c 100644 --- a/examples/quickstart-mxnet/README.md +++ b/examples/quickstart-mxnet/README.md @@ -1,5 +1,7 @@ # Flower Example using MXNet +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommended. + This example demonstrates how to run a MXNet machine learning project federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet projects. 
diff --git a/examples/quickstart-mxnet/pyproject.toml b/examples/quickstart-mxnet/pyproject.toml index a0d31f76ebdd..952683eb90f6 100644 --- a/examples/quickstart-mxnet/pyproject.toml +++ b/examples/quickstart-mxnet/pyproject.toml @@ -10,7 +10,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/quickstart-mxnet/requirements.txt b/examples/quickstart-mxnet/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/quickstart-mxnet/requirements.txt +++ b/examples/quickstart-mxnet/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 +flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index 2defc468c2ef..a25e6ea6ee36 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -1,6 +1,7 @@ # Flower Example using Pandas -This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. +This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +download, partition and preprocess the dataset. Running this example in itself is quite easy. 
## Project Setup @@ -69,13 +70,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -$ python3 client.py +$ python3 client.py --node-id 0 ``` Start client 2 in the second terminal: ```shell -$ python3 client.py +$ python3 client.py --node-id 1 ``` You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-pandas.html) for a detailed explanation. diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index 3feab3f6a0f4..c2f2605594d5 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -1,4 +1,4 @@ -import warnings +import argparse from typing import Dict, List, Tuple import numpy as np @@ -6,10 +6,10 @@ import flwr as fl +from flwr_datasets import FederatedDataset -df = pd.read_csv("./data/client.csv") -column_names = ["sepal length (cm)", "sepal width (cm)"] +column_names = ["sepal_length", "sepal_width"] def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: @@ -19,23 +19,47 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client class FlowerClient(fl.client.NumPyClient): + def __init__(self, X: pd.DataFrame): + self.X = X + def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: hist_list = [] # Execute query locally - for c in column_names: - hist = compute_hist(df, c) + for c in self.X.columns: + hist = compute_hist(self.X, c) hist_list.append(hist) return ( hist_list, - len(df), + len(self.X), {}, ) -# Start Flower client -fl.client.start_numpy_client( - server_address="127.0.0.1:8080", - client=FlowerClient(), -) +if __name__ == "__main__": + N_CLIENTS = 2 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + type=int, + 
choices=range(0, N_CLIENTS), + required=True, + help="Specifies the node id of artificially partitioned datasets.", + ) + args = parser.parse_args() + partition_id = args.node_id + + # Load the partition data + fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) + + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + # Use just the specified columns + X = dataset[column_names] + + # Start Flower client + fl.client.start_numpy_client( + server_address="127.0.0.1:8080", + client=FlowerClient(X), + ) diff --git a/examples/quickstart-pandas/pyproject.toml b/examples/quickstart-pandas/pyproject.toml index de20eaf61d63..6229210d6488 100644 --- a/examples/quickstart-pandas/pyproject.toml +++ b/examples/quickstart-pandas/pyproject.toml @@ -12,6 +12,6 @@ maintainers = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } numpy = "1.23.2" pandas = "2.0.0" -scikit-learn = "1.3.1" diff --git a/examples/quickstart-pandas/requirements.txt b/examples/quickstart-pandas/requirements.txt index 14308a55faaf..d44a3c6adab9 100644 --- a/examples/quickstart-pandas/requirements.txt +++ b/examples/quickstart-pandas/requirements.txt @@ -1,4 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 numpy==1.23.2 pandas==2.0.0 -scikit-learn==1.3.1 diff --git a/examples/quickstart-pandas/run.sh b/examples/quickstart-pandas/run.sh index 6b85ce30bf45..571fa8bfb3e4 100755 --- a/examples/quickstart-pandas/run.sh +++ b/examples/quickstart-pandas/run.sh @@ -2,13 +2,9 @@ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -# Download data -mkdir -p ./data -python -c "from sklearn.datasets import load_iris; load_iris(as_frame=True)['data'].to_csv('./data/client.csv')" - for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --node-id ${i} & done 
# This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pandas/server.py b/examples/quickstart-pandas/server.py index c82304374836..af4c2a796788 100644 --- a/examples/quickstart-pandas/server.py +++ b/examples/quickstart-pandas/server.py @@ -1,5 +1,4 @@ -import pickle -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -9,9 +8,6 @@ EvaluateRes, FitIns, FitRes, - Metrics, - MetricsAggregationFn, - NDArrays, Parameters, Scalar, ndarrays_to_parameters, @@ -23,11 +19,6 @@ class FedAnalytics(Strategy): - def __init__( - self, compute_fns: List[Callable] = None, col_names: List[str] = None - ) -> None: - super().__init__() - def initialize_parameters( self, client_manager: Optional[ClientManager] = None ) -> Optional[Parameters]: diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index 360efb8f6261..1287b50bca65 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1 +1,76 @@ -# Flower Examples using PyTorch Lightning +# Flower Example using PyTorch Lightning + +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-pytorch-lightning . 
&& rm -rf flower && cd quickstart-pytorch-lightning +``` + +This will create a new directory called `quickstart-pytorch-lightning` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py # client-side code +-- server.py # server-side code (including the strategy) +-- README.md +-- run.sh # runs server, then two clients +-- mnist.py # run a centralised version of this example +``` + +### Installing Dependencies + +Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with PyTorch and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the +following commands. 
+ +Start client 1 in the first terminal: + +```shell +python client.py --node-id 0 +``` + +Start client 2 in the second terminal: + +```shell +python client.py --node-id 1 +``` + +You will see that PyTorch Lightning is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch-lightning) for a detailed explanation. diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index e810d639974d..1dabd5732b9b 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -1,8 +1,14 @@ -import flwr as fl -import mnist -import pytorch_lightning as pl +import argparse from collections import OrderedDict + +import pytorch_lightning as pl import torch +from datasets.utils.logging import disable_progress_bar + +import flwr as fl +import mnist + +disable_progress_bar() class FlowerClient(fl.client.NumPyClient): @@ -50,9 +56,20 @@ def _set_parameters(model, parameters): def main() -> None: + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + type=int, + choices=range(0, 10), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + node_id = args.node_id + # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data() + train_loader, val_loader, test_loader = mnist.load_data(node_id) # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader) diff --git a/examples/quickstart-pytorch-lightning/mnist.py b/examples/quickstart-pytorch-lightning/mnist.py index c8f8374ecc04..95342f4fb9b3 100644 --- a/examples/quickstart-pytorch-lightning/mnist.py +++ b/examples/quickstart-pytorch-lightning/mnist.py @@ -3,14 +3,13 @@ Source: pytorchlightning.ai (2021/02/04) """ - +from flwr_datasets import FederatedDataset +import pytorch_lightning as pl import torch from torch import nn from torch.nn 
def collate_fn(batch):
    """Collate a list of ``{"image": ..., "label": ...}`` samples into a tuple.

    Args:
        batch: Sequence of sample dicts whose ``"image"`` entries are tensors
            of equal shape and whose ``"label"`` entries are scalars.

    Returns:
        ``(images, labels)``: the images stacked along a new leading batch
        dimension and the labels as a 1-D tensor, so the data loaders yield
        tuples rather than dictionaries.
    """
    images_tensor = torch.stack([item["image"] for item in batch])
    labels_tensor = torch.tensor([item["label"] for item in batch])
    return images_tensor, labels_tensor


def apply_transforms(batch):
    """Convert every entry of ``batch["image"]`` to a tensor.

    The batch dict is modified in place and also returned, which is the shape
    ``with_transform`` expects from its callback.
    """
    batch["image"] = [transforms.functional.to_tensor(img) for img in batch["image"]]
    return batch


def load_data(partition, seed=42):
    """Build the train/validation/test loaders for one MNIST partition.

    Args:
        partition: Index of the partition to load; the "train" split is
            divided into 10 partitions.
        seed: Seed used for both local splits so that a node ends up with the
            same train/validation/test membership every time it is started.

    Returns:
        ``(trainloader, valloader, testloader)`` holding 60 %, 20 % and 20 %
        of the partition respectively.
    """
    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
    # Keep the integer id and the loaded dataset under different names instead
    # of rebinding the argument.
    node_data = fds.load_partition(partition, "train").with_transform(apply_transforms)

    # 20 % held out for federated evaluation
    full_split = node_data.train_test_split(test_size=0.2, seed=seed)
    # 75 % of the remaining 80 %: 60 % for federated training and 20 % for
    # federated validation (both used in fit)
    fit_split = full_split["train"].train_test_split(train_size=0.75, seed=seed)

    loader_kwargs = {"batch_size": 32, "collate_fn": collate_fn, "num_workers": 1}
    trainloader = DataLoader(fit_split["train"], shuffle=True, **loader_kwargs)
    valloader = DataLoader(fit_split["test"], **loader_kwargs)
    testloader = DataLoader(full_split["test"], **loader_kwargs)
    return trainloader, valloader, testloader
client.py & + python client.py --node-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch-lightning/server.py b/examples/quickstart-pytorch-lightning/server.py index 370186ae1d98..a104a1fffd26 100644 --- a/examples/quickstart-pytorch-lightning/server.py +++ b/examples/quickstart-pytorch-lightning/server.py @@ -11,7 +11,7 @@ def main() -> None: # Start Flower server for three rounds of federated learning fl.server.start_server( server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=10), + config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, ) diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index ad57645002f8..1edb42d1ec81 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -99,6 +99,7 @@ def apply_transforms(batch): parser.add_argument( "--node-id", choices=[0, 1, 2], + required=True, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md new file mode 100644 index 000000000000..d62525c96c18 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/README.md @@ -0,0 +1,77 @@ +# Flower Example using scikit-learn + +This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on +"iris" (tabular) dataset. +It will help you understand how to adapt Flower for use with `scikit-learn`. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +download, partition and preprocess the dataset. + +## Project Setup + +Start by cloning the example project. 
We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-sklearn-tabular . && rm -rf flower && cd quickstart-sklearn-tabular +``` + +This will create a new directory called `quickstart-sklearn-tabular` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- utils.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with scikit-learn and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +poetry run python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. 
To do so simply open two more terminals and run the following command in each: + +```shell +poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2} +``` + +Alternatively you can run all of it in one shell as follows: + +```shell +poetry run python3 server.py & +poetry run python3 client.py --node-id 0 & +poetry run python3 client.py --node-id 1 +``` + +You will see that Flower is starting a federated training. diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py new file mode 100644 index 000000000000..5dc0e88b3c75 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/client.py @@ -0,0 +1,73 @@ +import argparse +import warnings + +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import log_loss + +import flwr as fl +import utils +from flwr_datasets import FederatedDataset + +if __name__ == "__main__": + N_CLIENTS = 3 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.node_id + + # Load the partition data + fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) + + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + X = dataset[["petal_length", "petal_width", "sepal_length", "sepal_width"]] + y = dataset["species"] + unique_labels = fds.load_full("train").unique("species") + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] + + # Create LogisticRegression Model + model = LogisticRegression( + penalty="l2", + max_iter=1, # local epoch + warm_start=True, # prevent refreshing weights when fitting + ) + + # Setting initial parameters, akin to model.compile for keras models + 
class IrisClient(fl.client.NumPyClient):
    """Flower NumPy client backed by the script-level ``model`` and data splits."""

    def get_parameters(self, config):  # type: ignore
        """Return the current parameters of ``model``."""
        return utils.get_model_parameters(model)

    def fit(self, parameters, config):  # type: ignore
        """Load ``parameters`` into ``model``, fit it on the local training data
        and return the updated parameters, the number of training rows and the
        training accuracy."""
        utils.set_model_params(model, parameters)
        # Ignore convergence failure due to low local epochs
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
        train_metrics = {"train_accuracy": model.score(X_train, y_train)}
        return utils.get_model_parameters(model), len(X_train), train_metrics

    def evaluate(self, parameters, config):  # type: ignore
        """Load ``parameters`` into ``model`` and return the log-loss, the number
        of test rows and the test accuracy on the local test data."""
        utils.set_model_params(model, parameters)
        probabilities = model.predict_proba(X_test)
        loss = log_loss(y_test, probabilities, labels=unique_labels)
        test_metrics = {"test_accuracy": model.score(X_test, y_test)}
        return loss, len(X_test), test_metrics
+flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +scikit-learn>=1.3.0 diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh new file mode 100755 index 000000000000..48cee1b41b74 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in $(seq 0 1); do + echo "Starting client $i" + python client.py --node-id "${i}" & +done + +# This will allow you to use CTRL+C to stop all background processes +trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/quickstart-sklearn-tabular/server.py b/examples/quickstart-sklearn-tabular/server.py new file mode 100644 index 000000000000..0c779c52a8d6 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/server.py @@ -0,0 +1,19 @@ +import flwr as fl +import utils +from sklearn.linear_model import LogisticRegression + + +# Start Flower server for five rounds of federated learning +if __name__ == "__main__": + model = LogisticRegression() + utils.set_initial_params(model, n_classes=3, n_features=4) + strategy = fl.server.strategy.FedAvg( + min_available_clients=2, + fit_metrics_aggregation_fn=utils.weighted_average, + evaluate_metrics_aggregation_fn=utils.weighted_average, + ) + fl.server.start_server( + server_address="0.0.0.0:8080", + strategy=strategy, + config=fl.server.ServerConfig(num_rounds=25), + ) diff --git a/examples/quickstart-sklearn-tabular/utils.py b/examples/quickstart-sklearn-tabular/utils.py new file mode 100644 index 000000000000..e154f44ef8bf --- /dev/null +++ b/examples/quickstart-sklearn-tabular/utils.py @@ -0,0 +1,75 @@ +from typing import List, Tuple, Dict + +import numpy as np +from sklearn.linear_model import LogisticRegression + +from flwr.common 
def get_model_parameters(model: "LogisticRegression") -> "NDArrays":
    """Return the parameters of a sklearn LogisticRegression model.

    The result is ``[coef_, intercept_]`` when the model fits an intercept and
    ``[coef_]`` otherwise.
    """
    if model.fit_intercept:
        return [model.coef_, model.intercept_]
    return [model.coef_]


def set_model_params(
    model: "LogisticRegression", params: "NDArrays"
) -> "LogisticRegression":
    """Set the parameters of a sklearn LogisticRegression model.

    ``params`` uses the layout produced by ``get_model_parameters``. The model
    is modified in place and returned.
    """
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: "LogisticRegression", n_classes: int, n_features: int):
    """Set initial parameters as zeros.

    Required since model params are uninitialized until model.fit is called but
    server asks for initial parameters from clients at launch. Refer to
    sklearn.linear_model.LogisticRegression documentation for more information.
    """
    model.classes_ = np.arange(n_classes)

    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))


def weighted_average(metrics: "List[Tuple[int, Metrics]]") -> "Dict[str, Scalar]":
    """Compute the sample-weighted average of every numeric metric.

    Only ``float`` and ``int`` values are averaged; entries of any other type
    are dropped.

    Args:
        metrics: One ``(num_samples, metrics_dict)`` pair per client.
            ``num_samples`` may count samples or batches depending on the
            client.

    Returns:
        A dict mapping each numeric metric name to
        ``sum(num_samples * value) / sum(num_samples)``. The denominator is
        the total over all clients, including clients that did not report
        that metric. An empty dict is returned when the total is zero.
    """
    total_samples = sum(num_samples for num_samples, _ in metrics)
    if total_samples == 0:
        # Nothing to weight by; avoid dividing by zero.
        return {}

    weighted_sums: Dict[str, float] = {}
    for num_samples, client_metrics in metrics:
        for name, value in client_metrics.items():
            if isinstance(value, (float, int)):
                # Create the accumulator on first sight of a metric, so a
                # metric that only some clients report cannot raise KeyError.
                weighted_sums[name] = weighted_sums.get(name, 0.0) + float(
                    num_samples * value
                )

    return {name: total / total_samples for name, total in weighted_sums.items()}
## Project Setup @@ -50,7 +50,7 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: +Afterward, you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: ```shell poetry run python3 server.py @@ -62,7 +62,7 @@ Now you are ready to start the Flower clients which will participate in the lear poetry run python3 client.py ``` -Alternatively you can run all of it in one shell as follows: +Alternatively, you can run all of it in one shell as follows: ```shell poetry run python3 server.py & diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index fc367e2c3053..d998adbdd899 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -1,16 +1,38 @@ +import argparse import os import flwr as fl import tensorflow as tf - +from flwr_datasets import FederatedDataset # Make TensorFlow log less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +# Parse arguments +parser = argparse.ArgumentParser(description="Flower") +parser.add_argument( + "--node-id", + type=int, + choices=[0, 1, 2], + required=True, + help="Partition of the dataset (0,1 or 2). 
" + "The dataset is divided into 3 partitions created artificially.", +) +args = parser.parse_args() + # Load model and data (MobileNetV2, CIFAR-10) model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) -(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() + +# Download and partition dataset +fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) +partition = fds.load_partition(args.node_id, "train") +partition.set_format("numpy") + +# Divide data on each node: 80% train, 20% test +partition = partition.train_test_split(test_size=0.2) +x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] +x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] # Define Flower client diff --git a/examples/quickstart-tensorflow/pyproject.toml b/examples/quickstart-tensorflow/pyproject.toml index 68d4f9aada52..e027a7353181 100644 --- a/examples/quickstart-tensorflow/pyproject.toml +++ b/examples/quickstart-tensorflow/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} diff --git a/examples/quickstart-tensorflow/requirements.txt b/examples/quickstart-tensorflow/requirements.txt index 6420aab25ec8..7f025975cae9 100644 --- a/examples/quickstart-tensorflow/requirements.txt +++ b/examples/quickstart-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" tensorflow-cpu>=2.9.1, != 
# Define metric aggregation function
def weighted_average(metrics: "List[Tuple[int, Metrics]]") -> "Metrics":
    """Aggregate per-client accuracies into one example-weighted average.

    Args:
        metrics: One ``(num_examples, metrics_dict)`` pair per client; every
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": sum(num_examples * accuracy) / sum(num_examples)}``, or
        an empty dict when the clients reported no examples at all.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # No examples to weight by; avoid dividing by zero.
        return {}

    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}
a/examples/quickstart-xgboost-horizontal/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Federated XGBoost in Horizontal Setting (PyTorch) - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb)) - -This example demonstrates a federated XGBoost using Flower with PyTorch. This is a novel method to conduct federated XGBoost in the horizontal setting. It differs from the previous methods in the following ways: - -- We aggregate and conduct federated learning on client tree’s prediction outcomes by sending clients' built XGBoost trees to the server and then sharing to the clients. -- The exchange of privacy-sensitive information (gradients) is not needed. -- The model is a CNN with 1D convolution kernel size = the number of XGBoost trees in the client tree ensembles. -- Using 1D convolution, we make the tree learning rate (a hyperparameter of XGBoost) learnable. - -## Project Setup - -This implementation can be easily run in Google Colab with the button at the top of the README or as a standalone Jupyter notebook, -it will automatically download and extract the example data inside a `dataset` folder and `binary_classification` and `regression` sub-folders. - -## Datasets - -This implementation supports both binary classification and regression datasets in SVM light format, loaded from ([LIBSVM Data](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/)). Simply download the dataset files from the website and put them in the folder location indicated above. 
diff --git a/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb b/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb deleted file mode 100644 index 4d76e0c26023..000000000000 --- a/examples/quickstart-xgboost-horizontal/code_horizontal.ipynb +++ /dev/null @@ -1,1560 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 15871, - "status": "ok", - "timestamp": 1670356049976, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "2c588ea0-a383-4461-e633-794e73d0f57a" - }, - "outputs": [], - "source": [ - "import os\n", - "import urllib.request\n", - "import bz2\n", - "import shutil\n", - "\n", - "CLASSIFICATION_PATH = os.path.join(\"dataset\", \"binary_classification\")\n", - "REGRESSION_PATH = os.path.join(\"dataset\", \"regression\")\n", - "\n", - "if not os.path.exists(CLASSIFICATION_PATH):\n", - " os.makedirs(CLASSIFICATION_PATH)\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna.t\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna.t')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/cod-rna.r\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'cod-rna.r')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/ijcnn1.t.bz2\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'ijcnn1.t.bz2')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " 
\"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/ijcnn1.tr.bz2\",\n", - " f\"{os.path.join(CLASSIFICATION_PATH, 'ijcnn1.tr.bz2')}\",\n", - " )\n", - " for filepath in os.listdir(CLASSIFICATION_PATH):\n", - " if filepath[-3:] == \"bz2\":\n", - " abs_filepath = os.path.join(CLASSIFICATION_PATH, filepath)\n", - " with bz2.BZ2File(abs_filepath) as fr, open(abs_filepath[:-4], \"wb\") as fw:\n", - " shutil.copyfileobj(fr, fw)\n", - "\n", - "if not os.path.exists(REGRESSION_PATH):\n", - " os.makedirs(REGRESSION_PATH)\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'eunite2001')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001.t\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'eunite2001.t')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'YearPredictionMSD.bz2')}\",\n", - " )\n", - " urllib.request.urlretrieve(\n", - " \"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.t.bz2\",\n", - " f\"{os.path.join(REGRESSION_PATH, 'YearPredictionMSD.t.bz2')}\",\n", - " )\n", - " for filepath in os.listdir(REGRESSION_PATH):\n", - " if filepath[-3:] == \"bz2\":\n", - " abs_filepath = os.path.join(REGRESSION_PATH, filepath)\n", - " with bz2.BZ2File(abs_filepath) as fr, open(abs_filepath[:-4], \"wb\") as fw:\n", - " shutil.copyfileobj(fr, fw)\n", - "\n", - "\n", - "!nvidia-smi\n", - "!pip install matplotlib scikit-learn tqdm torch torchmetrics torchsummary xgboost\n", - "!pip install -U \"flwr-nightly[simulation]\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Import relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 7, - "status": "ok", - "timestamp": 1670356049977, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "5289e33e-e18e-491b-d536-6b1052598994" - }, - "outputs": [], - "source": [ - "import xgboost as xgb\n", - "from xgboost import XGBClassifier, XGBRegressor\n", - "from sklearn.metrics import mean_squared_error, accuracy_score\n", - "from sklearn.datasets import load_svmlight_file\n", - "\n", - "import numpy as np\n", - "import torch, torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torchvision\n", - "from torchmetrics import Accuracy, MeanSquaredError\n", - "from tqdm import trange, tqdm\n", - "from torchsummary import summary\n", - "from torch.utils.data import DataLoader, Dataset, random_split\n", - "\n", - "print(\"Imported modules.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Import Flower relevant modules for Federated XGBoost" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import flwr as fl\n", - "from flwr.common.typing import Parameters\n", - "from collections import OrderedDict\n", - "from typing import Any, Dict, List, Optional, Tuple, Union\n", - "from flwr.common import NDArray, NDArrays\n", - "\n", - "print(\"Imported modules.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define utility function for xgboost trees" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt # pylint: disable=E0401\n", - "\n", - "\n", - "def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:\n", - " \"\"\"Visualize the built xgboost tree.\"\"\"\n", - " xgb.plot_tree(tree, 
num_trees=n_tree)\n", - " plt.rcParams[\"figure.figsize\"] = [50, 10]\n", - " plt.show()\n", - "\n", - "\n", - "def construct_tree(\n", - " dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str\n", - ") -> Union[XGBClassifier, XGBRegressor]:\n", - " \"\"\"Construct a xgboost tree form tabular dataset.\"\"\"\n", - " if tree_type == \"BINARY\":\n", - " tree = xgb.XGBClassifier(\n", - " objective=\"binary:logistic\",\n", - " learning_rate=0.1,\n", - " max_depth=8,\n", - " n_estimators=n_estimators,\n", - " subsample=0.8,\n", - " colsample_bylevel=1,\n", - " colsample_bynode=1,\n", - " colsample_bytree=1,\n", - " alpha=5,\n", - " gamma=5,\n", - " num_parallel_tree=1,\n", - " min_child_weight=1,\n", - " )\n", - "\n", - " elif tree_type == \"REG\":\n", - " tree = xgb.XGBRegressor(\n", - " objective=\"reg:squarederror\",\n", - " learning_rate=0.1,\n", - " max_depth=8,\n", - " n_estimators=n_estimators,\n", - " subsample=0.8,\n", - " colsample_bylevel=1,\n", - " colsample_bynode=1,\n", - " colsample_bytree=1,\n", - " alpha=5,\n", - " gamma=5,\n", - " num_parallel_tree=1,\n", - " min_child_weight=1,\n", - " )\n", - "\n", - " tree.fit(dataset, label)\n", - " return tree\n", - "\n", - "\n", - "def construct_tree_from_loader(\n", - " dataset_loader: DataLoader, n_estimators: int, tree_type: str\n", - ") -> Union[XGBClassifier, XGBRegressor]:\n", - " \"\"\"Construct a xgboost tree form tabular dataset loader.\"\"\"\n", - " for dataset in dataset_loader:\n", - " data, label = dataset[0], dataset[1]\n", - " return construct_tree(data, label, n_estimators, tree_type)\n", - "\n", - "\n", - "def single_tree_prediction(\n", - " tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray\n", - ") -> Optional[NDArray]:\n", - " \"\"\"Extract the prediction result of a single tree in the xgboost tree\n", - " ensemble.\"\"\"\n", - " # How to access a single tree\n", - " # https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb\n", - " 
num_t = len(tree.get_booster().get_dump())\n", - " if n_tree > num_t:\n", - " print(\n", - " \"The tree index to be extracted is larger than the total number of trees.\"\n", - " )\n", - " return None\n", - "\n", - " return tree.predict( # type: ignore\n", - " dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True\n", - " )\n", - "\n", - "\n", - "def tree_encoding( # pylint: disable=R0914\n", - " trainloader: DataLoader,\n", - " client_trees: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> Optional[Tuple[NDArray, NDArray]]:\n", - " \"\"\"Transform the tabular dataset into prediction results using the\n", - " aggregated xgboost tree ensembles from all clients.\"\"\"\n", - " if trainloader is None:\n", - " return None\n", - "\n", - " for local_dataset in trainloader:\n", - " x_train, y_train = local_dataset[0], local_dataset[1]\n", - "\n", - " x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))\n", - " x_train_enc = np.array(x_train_enc, copy=True)\n", - "\n", - " temp_trees: Any = None\n", - " if isinstance(client_trees, list) is False:\n", - " temp_trees = [client_trees[0]] * client_num\n", - " elif isinstance(client_trees, list) and len(client_trees) != client_num:\n", - " temp_trees = [client_trees[0][0]] * client_num\n", - " else:\n", - " cids = []\n", - " temp_trees = []\n", - " for i, _ in enumerate(client_trees):\n", - " temp_trees.append(client_trees[i][0]) # type: ignore\n", - " cids.append(client_trees[i][1]) # type: ignore\n", - " sorted_index = np.argsort(np.asarray(cids))\n", - " temp_trees = np.asarray(temp_trees)[sorted_index]\n", - "\n", - " for i, _ in enumerate(temp_trees):\n", - " for j in range(client_tree_num):\n", - " x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(\n", - " temp_trees[i], j, x_train\n", - " )\n", - "\n", 
- " x_train_enc32: Any = np.float32(x_train_enc)\n", - " y_train32: Any = np.float32(y_train)\n", - "\n", - " x_train_enc32, y_train32 = torch.from_numpy(\n", - " np.expand_dims(x_train_enc32, axis=1) # type: ignore\n", - " ), torch.from_numpy(\n", - " np.expand_dims(y_train32, axis=-1) # type: ignore\n", - " )\n", - " return x_train_enc32, y_train32" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manually download and load the tabular dataset from LIBSVM data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 26613, - "status": "ok", - "timestamp": 1670356076585, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "22843504-faf0-44cf-aedd-1df8d0ec87a6" - }, - "outputs": [], - "source": [ - "# Datasets can be downloaded from LIBSVM Data: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/\n", - "binary_train = [\"cod-rna.t\", \"cod-rna\", \"ijcnn1.t\"]\n", - "binary_test = [\"cod-rna.r\", \"cod-rna.t\", \"ijcnn1.tr\"]\n", - "reg_train = [\"eunite2001\", \"YearPredictionMSD\"]\n", - "reg_test = [\"eunite2001.t\", \"YearPredictionMSD.t\"]\n", - "\n", - "# Define the type of training task. 
Binary classification: BINARY; Regression: REG\n", - "task_types = [\"BINARY\", \"REG\"]\n", - "task_type = task_types[0]\n", - "\n", - "# Select the downloaded training and test dataset\n", - "if task_type == \"BINARY\":\n", - " dataset_path = \"dataset/binary_classification/\"\n", - " train = binary_train[0]\n", - " test = binary_test[0]\n", - "elif task_type == \"REG\":\n", - " dataset_path = \"dataset/regression/\"\n", - " train = reg_train[0]\n", - " test = reg_test[0]\n", - "\n", - "data_train = load_svmlight_file(dataset_path + train, zero_based=False)\n", - "data_test = load_svmlight_file(dataset_path + test, zero_based=False)\n", - "\n", - "print(\"Task type selected is: \" + task_type)\n", - "print(\"Training dataset is: \" + train)\n", - "print(\"Test dataset is: \" + test)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess the tabular dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class TreeDataset(Dataset):\n", - " def __init__(self, data: NDArray, labels: NDArray) -> None:\n", - " self.labels = labels\n", - " self.data = data\n", - "\n", - " def __len__(self) -> int:\n", - " return len(self.labels)\n", - "\n", - " def __getitem__(self, idx: int) -> Dict[int, NDArray]:\n", - " label = self.labels[idx]\n", - " data = self.data[idx, :]\n", - " sample = {0: data, 1: label}\n", - " return sample\n", - "\n", - "\n", - "X_train = data_train[0].toarray()\n", - "y_train = data_train[1]\n", - "X_test = data_test[0].toarray()\n", - "y_test = data_test[1]\n", - "X_train.flags.writeable = True\n", - "y_train.flags.writeable = True\n", - "X_test.flags.writeable = True\n", - "y_test.flags.writeable = True\n", - "\n", - "# If the feature dimensions of the trainset and testset do not agree,\n", - "# specify n_features in the load_svmlight_file function in the above cell.\n", - "# 
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_svmlight_file.html\n", - "print(\"Feature dimension of the dataset:\", X_train.shape[1])\n", - "print(\"Size of the trainset:\", X_train.shape[0])\n", - "print(\"Size of the testset:\", X_test.shape[0])\n", - "assert X_train.shape[1] == X_test.shape[1]\n", - "\n", - "if task_type == \"BINARY\":\n", - " y_train[y_train == -1] = 0\n", - " y_test[y_test == -1] = 0\n", - "\n", - "trainset = TreeDataset(np.array(X_train, copy=True), np.array(y_train, copy=True))\n", - "testset = TreeDataset(np.array(X_test, copy=True), np.array(y_test, copy=True))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Conduct tabular dataset partition for Federated Learning" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_dataloader(\n", - " dataset: Dataset, partition: str, batch_size: Union[int, str]\n", - ") -> DataLoader:\n", - " if batch_size == \"whole\":\n", - " batch_size = len(dataset)\n", - " return DataLoader(\n", - " dataset, batch_size=batch_size, pin_memory=True, shuffle=(partition == \"train\")\n", - " )\n", - "\n", - "\n", - "# https://github.com/adap/flower\n", - "def do_fl_partitioning(\n", - " trainset: Dataset,\n", - " testset: Dataset,\n", - " pool_size: int,\n", - " batch_size: Union[int, str],\n", - " val_ratio: float = 0.0,\n", - ") -> Tuple[DataLoader, DataLoader, DataLoader]:\n", - " # Split training set into `num_clients` partitions to simulate different local datasets\n", - " partition_size = len(trainset) // pool_size\n", - " lengths = [partition_size] * pool_size\n", - " if sum(lengths) != len(trainset):\n", - " lengths[-1] = len(trainset) - sum(lengths[0:-1])\n", - " datasets = random_split(trainset, lengths, torch.Generator().manual_seed(0))\n", - "\n", - " # Split each partition into train/val and create DataLoader\n", - " trainloaders = []\n", - " valloaders = 
[]\n", - " for ds in datasets:\n", - " len_val = int(len(ds) * val_ratio)\n", - " len_train = len(ds) - len_val\n", - " lengths = [len_train, len_val]\n", - " ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(0))\n", - " trainloaders.append(get_dataloader(ds_train, \"train\", batch_size))\n", - " if len_val != 0:\n", - " valloaders.append(get_dataloader(ds_val, \"val\", batch_size))\n", - " else:\n", - " valloaders = None\n", - " testloader = get_dataloader(testset, \"test\", batch_size)\n", - " return trainloaders, valloaders, testloader" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define global variables for Federated XGBoost Learning" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The number of clients participated in the federated learning\n", - "client_num = 5\n", - "\n", - "# The number of XGBoost trees in the tree ensemble that will be built for each client\n", - "client_tree_num = 500 // client_num" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Build global XGBoost tree for comparison" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 1080216, - "status": "ok", - "timestamp": 1670357156788, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "d56f2821-5cd5-49ff-c5dc-f8d088eed799" - }, - "outputs": [], - "source": [ - "global_tree = construct_tree(X_train, y_train, client_tree_num, task_type)\n", - "preds_train = global_tree.predict(X_train)\n", - "preds_test = global_tree.predict(X_test)\n", - "\n", - "if task_type == \"BINARY\":\n", - " result_train = accuracy_score(y_train, preds_train)\n", - " result_test = accuracy_score(y_test, preds_test)\n", - " print(\"Global XGBoost Training Accuracy: 
%f\" % (result_train))\n", - " print(\"Global XGBoost Testing Accuracy: %f\" % (result_test))\n", - "elif task_type == \"REG\":\n", - " result_train = mean_squared_error(y_train, preds_train)\n", - " result_test = mean_squared_error(y_test, preds_test)\n", - " print(\"Global XGBoost Training MSE: %f\" % (result_train))\n", - " print(\"Global XGBoost Testing MSE: %f\" % (result_test))\n", - "\n", - "print(global_tree)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Simulate local XGBoost trees on clients for comparison" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "executionInfo": { - "elapsed": 242310, - "status": "ok", - "timestamp": 1670357399084, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "0739df9f-84de-4749-8de1-7bd7c6a32ccc" - }, - "outputs": [], - "source": [ - "client_trees_comparison = []\n", - "trainloaders, _, testloader = do_fl_partitioning(\n", - " trainset, testset, pool_size=client_num, batch_size=\"whole\", val_ratio=0.0\n", - ")\n", - "\n", - "for i, trainloader in enumerate(trainloaders):\n", - " for local_dataset in trainloader:\n", - " local_X_train, local_y_train = local_dataset[0], local_dataset[1]\n", - " tree = construct_tree(local_X_train, local_y_train, client_tree_num, task_type)\n", - " client_trees_comparison.append(tree)\n", - "\n", - " preds_train = client_trees_comparison[-1].predict(local_X_train)\n", - " preds_test = client_trees_comparison[-1].predict(X_test)\n", - "\n", - " if task_type == \"BINARY\":\n", - " result_train = accuracy_score(local_y_train, preds_train)\n", - " result_test = accuracy_score(y_test, preds_test)\n", - " print(\"Local Client %d XGBoost Training Accuracy: %f\" % (i, result_train))\n", - " print(\"Local Client %d XGBoost Testing Accuracy: %f\" % (i, result_test))\n", - " elif task_type == 
\"REG\":\n", - " result_train = mean_squared_error(local_y_train, preds_train)\n", - " result_test = mean_squared_error(y_test, preds_test)\n", - " print(\"Local Client %d XGBoost Training MSE: %f\" % (i, result_train))\n", - " print(\"Local Client %d XGBoost Testing MSE: %f\" % (i, result_test))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Centralized Federated XGBoost\n", - "#### Create 1D convolutional neural network on trees prediction results. \n", - "#### 1D kernel size == client_tree_num\n", - "#### Make the learning rate of the tree ensembles learnable." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 38, - "status": "ok", - "timestamp": 1670363021675, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "class CNN(nn.Module):\n", - " def __init__(self, n_channel: int = 64) -> None:\n", - " super(CNN, self).__init__()\n", - " n_out = 1\n", - " self.task_type = task_type\n", - " self.conv1d = nn.Conv1d(\n", - " 1, n_channel, kernel_size=client_tree_num, stride=client_tree_num, padding=0\n", - " )\n", - " self.layer_direct = nn.Linear(n_channel * client_num, n_out)\n", - " self.ReLU = nn.ReLU()\n", - " self.Sigmoid = nn.Sigmoid()\n", - " self.Identity = nn.Identity()\n", - "\n", - " # Add weight initialization\n", - " for layer in self.modules():\n", - " if isinstance(layer, nn.Linear):\n", - " nn.init.kaiming_uniform_(\n", - " layer.weight, mode=\"fan_in\", nonlinearity=\"relu\"\n", - " )\n", - "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " x = self.ReLU(self.conv1d(x))\n", - " x = x.flatten(start_dim=1)\n", - " x = self.ReLU(x)\n", - " if self.task_type == \"BINARY\":\n", - " x = self.Sigmoid(self.layer_direct(x))\n", - " elif self.task_type == \"REG\":\n", - " x = self.Identity(self.layer_direct(x))\n", - " return x\n", - "\n", - 
" def get_weights(self) -> fl.common.NDArrays:\n", - " \"\"\"Get model weights as a list of NumPy ndarrays.\"\"\"\n", - " return [\n", - " np.array(val.cpu().numpy(), copy=True)\n", - " for _, val in self.state_dict().items()\n", - " ]\n", - "\n", - " def set_weights(self, weights: fl.common.NDArrays) -> None:\n", - " \"\"\"Set model weights from a list of NumPy ndarrays.\"\"\"\n", - " layer_dict = {}\n", - " for k, v in zip(self.state_dict().keys(), weights):\n", - " if v.ndim != 0:\n", - " layer_dict[k] = torch.Tensor(np.array(v, copy=True))\n", - " state_dict = OrderedDict(layer_dict)\n", - " self.load_state_dict(state_dict, strict=True)\n", - "\n", - "\n", - "def train(\n", - " task_type: str,\n", - " net: CNN,\n", - " trainloader: DataLoader,\n", - " device: torch.device,\n", - " num_iterations: int,\n", - " log_progress: bool = True,\n", - ") -> Tuple[float, float, int]:\n", - " # Define loss and optimizer\n", - " if task_type == \"BINARY\":\n", - " criterion = nn.BCELoss()\n", - " elif task_type == \"REG\":\n", - " criterion = nn.MSELoss()\n", - " # optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-6)\n", - " optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999))\n", - "\n", - " def cycle(iterable):\n", - " \"\"\"Repeats the contents of the train loader, in case it gets exhausted in 'num_iterations'.\"\"\"\n", - " while True:\n", - " for x in iterable:\n", - " yield x\n", - "\n", - " # Train the network\n", - " net.train()\n", - " total_loss, total_result, n_samples = 0.0, 0.0, 0\n", - " pbar = (\n", - " tqdm(iter(cycle(trainloader)), total=num_iterations, desc=f\"TRAIN\")\n", - " if log_progress\n", - " else iter(cycle(trainloader))\n", - " )\n", - "\n", - " # Unusually, this training is formulated in terms of number of updates/iterations/batches processed\n", - " # by the network. 
This will be helpful later on, when partitioning the data across clients: resulting\n", - " # in differences between dataset sizes and hence inconsistent numbers of updates per 'epoch'.\n", - " for i, data in zip(range(num_iterations), pbar):\n", - " tree_outputs, labels = data[0].to(device), data[1].to(device)\n", - " optimizer.zero_grad()\n", - "\n", - " outputs = net(tree_outputs)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # Collected training loss and accuracy statistics\n", - " total_loss += loss.item()\n", - " n_samples += labels.size(0)\n", - "\n", - " if task_type == \"BINARY\":\n", - " acc = Accuracy(task=\"binary\")(outputs, labels.type(torch.int))\n", - " total_result += acc * labels.size(0)\n", - " elif task_type == \"REG\":\n", - " mse = MeanSquaredError()(outputs, labels.type(torch.int))\n", - " total_result += mse * labels.size(0)\n", - "\n", - " if log_progress:\n", - " if task_type == \"BINARY\":\n", - " pbar.set_postfix(\n", - " {\n", - " \"train_loss\": total_loss / n_samples,\n", - " \"train_acc\": total_result / n_samples,\n", - " }\n", - " )\n", - " elif task_type == \"REG\":\n", - " pbar.set_postfix(\n", - " {\n", - " \"train_loss\": total_loss / n_samples,\n", - " \"train_mse\": total_result / n_samples,\n", - " }\n", - " )\n", - " if log_progress:\n", - " print(\"\\n\")\n", - "\n", - " return total_loss / n_samples, total_result / n_samples, n_samples\n", - "\n", - "\n", - "def test(\n", - " task_type: str,\n", - " net: CNN,\n", - " testloader: DataLoader,\n", - " device: torch.device,\n", - " log_progress: bool = True,\n", - ") -> Tuple[float, float, int]:\n", - " \"\"\"Evaluates the network on test data.\"\"\"\n", - " if task_type == \"BINARY\":\n", - " criterion = nn.BCELoss()\n", - " elif task_type == \"REG\":\n", - " criterion = nn.MSELoss()\n", - "\n", - " total_loss, total_result, n_samples = 0.0, 0.0, 0\n", - " net.eval()\n", - " with torch.no_grad():\n", - " pbar = 
tqdm(testloader, desc=\"TEST\") if log_progress else testloader\n", - " for data in pbar:\n", - " tree_outputs, labels = data[0].to(device), data[1].to(device)\n", - " outputs = net(tree_outputs)\n", - "\n", - " # Collected testing loss and accuracy statistics\n", - " total_loss += criterion(outputs, labels).item()\n", - " n_samples += labels.size(0)\n", - "\n", - " if task_type == \"BINARY\":\n", - " acc = Accuracy(task=\"binary\")(\n", - " outputs.cpu(), labels.type(torch.int).cpu()\n", - " )\n", - " total_result += acc * labels.size(0)\n", - " elif task_type == \"REG\":\n", - " mse = MeanSquaredError()(outputs.cpu(), labels.type(torch.int).cpu())\n", - " total_result += mse * labels.size(0)\n", - "\n", - " if log_progress:\n", - " print(\"\\n\")\n", - "\n", - " return total_loss / n_samples, total_result / n_samples, n_samples" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Flower custom client\n", - "## Import Flower custom client relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Flower client\n", - "from flwr.common import (\n", - " EvaluateIns,\n", - " EvaluateRes,\n", - " FitIns,\n", - " FitRes,\n", - " GetPropertiesIns,\n", - " GetPropertiesRes,\n", - " GetParametersIns,\n", - " GetParametersRes,\n", - " Status,\n", - " Code,\n", - " parameters_to_ndarrays,\n", - " ndarrays_to_parameters,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 36, - "status": "ok", - "timestamp": 1670363021676, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "def tree_encoding_loader(\n", - " dataloader: DataLoader,\n", - " batch_size: int,\n", - " client_trees: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, 
int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> DataLoader:\n", - " encoding = tree_encoding(dataloader, client_trees, client_tree_num, client_num)\n", - " if encoding is None:\n", - " return None\n", - " data, labels = encoding\n", - " tree_dataset = TreeDataset(data, labels)\n", - " return get_dataloader(tree_dataset, \"tree\", batch_size)\n", - "\n", - "\n", - "class FL_Client(fl.client.Client):\n", - " def __init__(\n", - " self,\n", - " task_type: str,\n", - " trainloader: DataLoader,\n", - " valloader: DataLoader,\n", - " client_tree_num: int,\n", - " client_num: int,\n", - " cid: str,\n", - " log_progress: bool = False,\n", - " ):\n", - " \"\"\"\n", - " Creates a client for training `network.Net` on tabular dataset.\n", - " \"\"\"\n", - " self.task_type = task_type\n", - " self.cid = cid\n", - " self.tree = construct_tree_from_loader(trainloader, client_tree_num, task_type)\n", - " self.trainloader_original = trainloader\n", - " self.valloader_original = valloader\n", - " self.trainloader = None\n", - " self.valloader = None\n", - " self.client_tree_num = client_tree_num\n", - " self.client_num = client_num\n", - " self.properties = {\"tensor_type\": \"numpy.ndarray\"}\n", - " self.log_progress = log_progress\n", - "\n", - " # instantiate model\n", - " self.net = CNN()\n", - "\n", - " # determine device\n", - " self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - " def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:\n", - " return GetPropertiesRes(properties=self.properties)\n", - "\n", - " def get_parameters(\n", - " self, ins: GetParametersIns\n", - " ) -> Tuple[\n", - " GetParametersRes, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]\n", - " ]:\n", - " return [\n", - " GetParametersRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=ndarrays_to_parameters(self.net.get_weights()),\n", - " ),\n", - " (self.tree, 
int(self.cid)),\n", - " ]\n", - "\n", - " def set_parameters(\n", - " self,\n", - " parameters: Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " ],\n", - " ) -> Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ]:\n", - " self.net.set_weights(parameters_to_ndarrays(parameters[0]))\n", - " return parameters[1]\n", - "\n", - " def fit(self, fit_params: FitIns) -> FitRes:\n", - " # Process incoming request to train\n", - " num_iterations = fit_params.config[\"num_iterations\"]\n", - " batch_size = fit_params.config[\"batch_size\"]\n", - " aggregated_trees = self.set_parameters(fit_params.parameters)\n", - "\n", - " if type(aggregated_trees) is list:\n", - " print(\"Client \" + self.cid + \": recieved\", len(aggregated_trees), \"trees\")\n", - " else:\n", - " print(\"Client \" + self.cid + \": only had its own tree\")\n", - " self.trainloader = tree_encoding_loader(\n", - " self.trainloader_original,\n", - " batch_size,\n", - " aggregated_trees,\n", - " self.client_tree_num,\n", - " self.client_num,\n", - " )\n", - " self.valloader = tree_encoding_loader(\n", - " self.valloader_original,\n", - " batch_size,\n", - " aggregated_trees,\n", - " self.client_tree_num,\n", - " self.client_num,\n", - " )\n", - "\n", - " # num_iterations = None special behaviour: train(...) 
runs for a single epoch, however many updates it may be\n", - " num_iterations = num_iterations or len(self.trainloader)\n", - "\n", - " # Train the model\n", - " print(f\"Client {self.cid}: training for {num_iterations} iterations/updates\")\n", - " self.net.to(self.device)\n", - " train_loss, train_result, num_examples = train(\n", - " self.task_type,\n", - " self.net,\n", - " self.trainloader,\n", - " device=self.device,\n", - " num_iterations=num_iterations,\n", - " log_progress=self.log_progress,\n", - " )\n", - " print(\n", - " f\"Client {self.cid}: training round complete, {num_examples} examples processed\"\n", - " )\n", - "\n", - " # Return training information: model, number of examples processed and metrics\n", - " if self.task_type == \"BINARY\":\n", - " return FitRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=self.get_parameters(fit_params.config),\n", - " num_examples=num_examples,\n", - " metrics={\"loss\": train_loss, \"accuracy\": train_result},\n", - " )\n", - " elif self.task_type == \"REG\":\n", - " return FitRes(\n", - " status=Status(Code.OK, \"\"),\n", - " parameters=self.get_parameters(fit_params.config),\n", - " num_examples=num_examples,\n", - " metrics={\"loss\": train_loss, \"mse\": train_result},\n", - " )\n", - "\n", - " def evaluate(self, eval_params: EvaluateIns) -> EvaluateRes:\n", - " # Process incoming request to evaluate\n", - " self.set_parameters(eval_params.parameters)\n", - "\n", - " # Evaluate the model\n", - " self.net.to(self.device)\n", - " loss, result, num_examples = test(\n", - " self.task_type,\n", - " self.net,\n", - " self.valloader,\n", - " device=self.device,\n", - " log_progress=self.log_progress,\n", - " )\n", - "\n", - " # Return evaluation information\n", - " if self.task_type == \"BINARY\":\n", - " print(\n", - " f\"Client {self.cid}: evaluation on {num_examples} examples: loss={loss:.4f}, accuracy={result:.4f}\"\n", - " )\n", - " return EvaluateRes(\n", - " status=Status(Code.OK, \"\"),\n", - 
" loss=loss,\n", - " num_examples=num_examples,\n", - " metrics={\"accuracy\": result},\n", - " )\n", - " elif self.task_type == \"REG\":\n", - " print(\n", - " f\"Client {self.cid}: evaluation on {num_examples} examples: loss={loss:.4f}, mse={result:.4f}\"\n", - " )\n", - " return EvaluateRes(\n", - " status=Status(Code.OK, \"\"),\n", - " loss=loss,\n", - " num_examples=num_examples,\n", - " metrics={\"mse\": result},\n", - " )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Flower custom server\n", - "## Import Flower custom server relevant modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Flower server\n", - "import functools\n", - "from flwr.server.strategy import FedXgbNnAvg\n", - "from flwr.server.app import ServerConfig\n", - "\n", - "import timeit\n", - "from logging import DEBUG, INFO\n", - "from typing import Dict, List, Optional, Tuple, Union\n", - "\n", - "from flwr.common import DisconnectRes, Parameters, ReconnectIns, Scalar\n", - "from flwr.common.logger import log\n", - "from flwr.common.typing import GetParametersIns\n", - "from flwr.server.client_manager import ClientManager, SimpleClientManager\n", - "from flwr.server.client_proxy import ClientProxy\n", - "from flwr.server.history import History\n", - "from flwr.server.strategy import Strategy\n", - "from flwr.server.server import (\n", - " reconnect_clients,\n", - " reconnect_client,\n", - " fit_clients,\n", - " fit_client,\n", - " _handle_finished_future_after_fit,\n", - " evaluate_clients,\n", - " evaluate_client,\n", - " _handle_finished_future_after_evaluate,\n", - ")\n", - "\n", - "FitResultsAndFailures = Tuple[\n", - " List[Tuple[ClientProxy, FitRes]],\n", - " List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n", - "]\n", - "EvaluateResultsAndFailures = Tuple[\n", - " List[Tuple[ClientProxy, EvaluateRes]],\n", - " List[Union[Tuple[ClientProxy, 
EvaluateRes], BaseException]],\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FL_Server(fl.server.Server):\n", - " \"\"\"Flower server.\"\"\"\n", - "\n", - " def __init__(\n", - " self, *, client_manager: ClientManager, strategy: Optional[Strategy] = None\n", - " ) -> None:\n", - " self._client_manager: ClientManager = client_manager\n", - " self.parameters: Parameters = Parameters(\n", - " tensors=[], tensor_type=\"numpy.ndarray\"\n", - " )\n", - " self.strategy: Strategy = strategy\n", - " self.max_workers: Optional[int] = None\n", - "\n", - " # pylint: disable=too-many-locals\n", - " def fit(self, num_rounds: int, timeout: Optional[float]) -> History:\n", - " \"\"\"Run federated averaging for a number of rounds.\"\"\"\n", - " history = History()\n", - "\n", - " # Initialize parameters\n", - " log(INFO, \"Initializing global parameters\")\n", - " self.parameters = self._get_initial_parameters(timeout=timeout)\n", - "\n", - " log(INFO, \"Evaluating initial parameters\")\n", - " res = self.strategy.evaluate(0, parameters=self.parameters)\n", - " if res is not None:\n", - " log(\n", - " INFO,\n", - " \"initial parameters (loss, other metrics): %s, %s\",\n", - " res[0],\n", - " res[1],\n", - " )\n", - " history.add_loss_centralized(server_round=0, loss=res[0])\n", - " history.add_metrics_centralized(server_round=0, metrics=res[1])\n", - "\n", - " # Run federated learning for num_rounds\n", - " log(INFO, \"FL starting\")\n", - " start_time = timeit.default_timer()\n", - "\n", - " for current_round in range(1, num_rounds + 1):\n", - " # Train model and replace previous global model\n", - " res_fit = self.fit_round(server_round=current_round, timeout=timeout)\n", - " if res_fit:\n", - " parameters_prime, _, _ = res_fit # fit_metrics_aggregated\n", - " if parameters_prime:\n", - " self.parameters = parameters_prime\n", - "\n", - " # Evaluate model using strategy implementation\n", - " 
res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)\n", - " if res_cen is not None:\n", - " loss_cen, metrics_cen = res_cen\n", - " log(\n", - " INFO,\n", - " \"fit progress: (%s, %s, %s, %s)\",\n", - " current_round,\n", - " loss_cen,\n", - " metrics_cen,\n", - " timeit.default_timer() - start_time,\n", - " )\n", - " history.add_loss_centralized(server_round=current_round, loss=loss_cen)\n", - " history.add_metrics_centralized(\n", - " server_round=current_round, metrics=metrics_cen\n", - " )\n", - "\n", - " # Evaluate model on a sample of available clients\n", - " res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)\n", - " if res_fed:\n", - " loss_fed, evaluate_metrics_fed, _ = res_fed\n", - " if loss_fed:\n", - " history.add_loss_distributed(\n", - " server_round=current_round, loss=loss_fed\n", - " )\n", - " history.add_metrics_distributed(\n", - " server_round=current_round, metrics=evaluate_metrics_fed\n", - " )\n", - "\n", - " # Bookkeeping\n", - " end_time = timeit.default_timer()\n", - " elapsed = end_time - start_time\n", - " log(INFO, \"FL finished in %s\", elapsed)\n", - " return history\n", - "\n", - " def evaluate_round(\n", - " self,\n", - " server_round: int,\n", - " timeout: Optional[float],\n", - " ) -> Optional[\n", - " Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]\n", - " ]:\n", - " \"\"\"Validate current global model on a number of clients.\"\"\"\n", - "\n", - " # Get clients and their respective instructions from strategy\n", - " client_instructions = self.strategy.configure_evaluate(\n", - " server_round=server_round,\n", - " parameters=self.parameters,\n", - " client_manager=self._client_manager,\n", - " )\n", - " if not client_instructions:\n", - " log(INFO, \"evaluate_round %s: no clients selected, cancel\", server_round)\n", - " return None\n", - " log(\n", - " DEBUG,\n", - " \"evaluate_round %s: strategy sampled %s clients (out of %s)\",\n", - " server_round,\n", - " 
len(client_instructions),\n", - " self._client_manager.num_available(),\n", - " )\n", - "\n", - " # Collect `evaluate` results from all clients participating in this round\n", - " results, failures = evaluate_clients(\n", - " client_instructions,\n", - " max_workers=self.max_workers,\n", - " timeout=timeout,\n", - " )\n", - " log(\n", - " DEBUG,\n", - " \"evaluate_round %s received %s results and %s failures\",\n", - " server_round,\n", - " len(results),\n", - " len(failures),\n", - " )\n", - "\n", - " # Aggregate the evaluation results\n", - " aggregated_result: Tuple[\n", - " Optional[float],\n", - " Dict[str, Scalar],\n", - " ] = self.strategy.aggregate_evaluate(server_round, results, failures)\n", - "\n", - " loss_aggregated, metrics_aggregated = aggregated_result\n", - " return loss_aggregated, metrics_aggregated, (results, failures)\n", - "\n", - " def fit_round(\n", - " self,\n", - " server_round: int,\n", - " timeout: Optional[float],\n", - " ) -> Optional[\n", - " Tuple[\n", - " Optional[\n", - " Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[\n", - " Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]\n", - " ],\n", - " ],\n", - " ]\n", - " ],\n", - " Dict[str, Scalar],\n", - " FitResultsAndFailures,\n", - " ]\n", - " ]:\n", - " \"\"\"Perform a single round of federated averaging.\"\"\"\n", - "\n", - " # Get clients and their respective instructions from strategy\n", - " client_instructions = self.strategy.configure_fit(\n", - " server_round=server_round,\n", - " parameters=self.parameters,\n", - " client_manager=self._client_manager,\n", - " )\n", - "\n", - " if not client_instructions:\n", - " log(INFO, \"fit_round %s: no clients selected, cancel\", server_round)\n", - " return None\n", - " log(\n", - " DEBUG,\n", - " \"fit_round %s: strategy sampled %s clients (out of %s)\",\n", - " server_round,\n", - " len(client_instructions),\n", - " 
self._client_manager.num_available(),\n", - " )\n", - "\n", - " # Collect `fit` results from all clients participating in this round\n", - " results, failures = fit_clients(\n", - " client_instructions=client_instructions,\n", - " max_workers=self.max_workers,\n", - " timeout=timeout,\n", - " )\n", - "\n", - " log(\n", - " DEBUG,\n", - " \"fit_round %s received %s results and %s failures\",\n", - " server_round,\n", - " len(results),\n", - " len(failures),\n", - " )\n", - "\n", - " # Aggregate training results\n", - " NN_aggregated: Parameters\n", - " trees_aggregated: Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ]\n", - " metrics_aggregated: Dict[str, Scalar]\n", - " aggregated, metrics_aggregated = self.strategy.aggregate_fit(\n", - " server_round, results, failures\n", - " )\n", - " NN_aggregated, trees_aggregated = aggregated[0], aggregated[1]\n", - "\n", - " if type(trees_aggregated) is list:\n", - " print(\"Server side aggregated\", len(trees_aggregated), \"trees.\")\n", - " else:\n", - " print(\"Server side did not aggregate trees.\")\n", - "\n", - " return (\n", - " [NN_aggregated, trees_aggregated],\n", - " metrics_aggregated,\n", - " (results, failures),\n", - " )\n", - "\n", - " def _get_initial_parameters(\n", - " self, timeout: Optional[float]\n", - " ) -> Tuple[Parameters, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]]:\n", - " \"\"\"Get initial parameters from one of the available clients.\"\"\"\n", - "\n", - " # Server-side parameter initialization\n", - " parameters: Optional[Parameters] = self.strategy.initialize_parameters(\n", - " client_manager=self._client_manager\n", - " )\n", - " if parameters is not None:\n", - " log(INFO, \"Using initial parameters provided by strategy\")\n", - " return parameters\n", - "\n", - " # Get initial parameters from one of the clients\n", - " log(INFO, \"Requesting initial parameters from 
one random client\")\n", - " random_client = self._client_manager.sample(1)[0]\n", - " ins = GetParametersIns(config={})\n", - " get_parameters_res_tree = random_client.get_parameters(ins=ins, timeout=timeout)\n", - " parameters = [get_parameters_res_tree[0].parameters, get_parameters_res_tree[1]]\n", - " log(INFO, \"Received initial parameters from one random client\")\n", - "\n", - " return parameters" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create server-side evaluation and experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 35, - "status": "ok", - "timestamp": 1670363021676, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - } - }, - "outputs": [], - "source": [ - "def print_model_layers(model: nn.Module) -> None:\n", - " print(model)\n", - " for param_tensor in model.state_dict():\n", - " print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n", - "\n", - "\n", - "def serverside_eval(\n", - " server_round: int,\n", - " parameters: Tuple[\n", - " Parameters,\n", - " Union[\n", - " Tuple[XGBClassifier, int],\n", - " Tuple[XGBRegressor, int],\n", - " List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],\n", - " ],\n", - " ],\n", - " config: Dict[str, Scalar],\n", - " task_type: str,\n", - " testloader: DataLoader,\n", - " batch_size: int,\n", - " client_tree_num: int,\n", - " client_num: int,\n", - ") -> Tuple[float, Dict[str, float]]:\n", - " \"\"\"An evaluation function for centralized/serverside evaluation over the entire test set.\"\"\"\n", - " # device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - " device = \"cpu\"\n", - " model = CNN()\n", - " # print_model_layers(model)\n", - "\n", - " model.set_weights(parameters_to_ndarrays(parameters[0]))\n", - " model.to(device)\n", - "\n", - " trees_aggregated = parameters[1]\n", - 
" testloader = tree_encoding_loader(\n", - " testloader, batch_size, trees_aggregated, client_tree_num, client_num\n", - " )\n", - " loss, result, _ = test(\n", - " task_type, model, testloader, device=device, log_progress=False\n", - " )\n", - "\n", - " if task_type == \"BINARY\":\n", - " print(\n", - " f\"Evaluation on the server: test_loss={loss:.4f}, test_accuracy={result:.4f}\"\n", - " )\n", - " return loss, {\"accuracy\": result}\n", - " elif task_type == \"REG\":\n", - " print(f\"Evaluation on the server: test_loss={loss:.4f}, test_mse={result:.4f}\")\n", - " return loss, {\"mse\": result}\n", - "\n", - "\n", - "def start_experiment(\n", - " task_type: str,\n", - " trainset: Dataset,\n", - " testset: Dataset,\n", - " num_rounds: int = 5,\n", - " client_tree_num: int = 50,\n", - " client_pool_size: int = 5,\n", - " num_iterations: int = 100,\n", - " fraction_fit: float = 1.0,\n", - " min_fit_clients: int = 2,\n", - " batch_size: int = 32,\n", - " val_ratio: float = 0.1,\n", - ") -> History:\n", - " client_resources = {\"num_cpus\": 0.5} # 2 clients per CPU\n", - "\n", - " # Partition the dataset into subsets reserved for each client.\n", - " # - 'val_ratio' controls the proportion of the (local) client reserved as a local test set\n", - " # (good for testing how the final model performs on the client's local unseen data)\n", - " trainloaders, valloaders, testloader = do_fl_partitioning(\n", - " trainset,\n", - " testset,\n", - " batch_size=\"whole\",\n", - " pool_size=client_pool_size,\n", - " val_ratio=val_ratio,\n", - " )\n", - " print(\n", - " f\"Data partitioned across {client_pool_size} clients\"\n", - " f\" and {val_ratio} of local dataset reserved for validation.\"\n", - " )\n", - "\n", - " # Configure the strategy\n", - " def fit_config(server_round: int) -> Dict[str, Scalar]:\n", - " print(f\"Configuring round {server_round}\")\n", - " return {\n", - " \"num_iterations\": num_iterations,\n", - " \"batch_size\": batch_size,\n", - " }\n", - "\n", - " # 
FedXgbNnAvg\n", - " strategy = FedXgbNnAvg(\n", - " fraction_fit=fraction_fit,\n", - " fraction_evaluate=fraction_fit if val_ratio > 0.0 else 0.0,\n", - " min_fit_clients=min_fit_clients,\n", - " min_evaluate_clients=min_fit_clients,\n", - " min_available_clients=client_pool_size, # all clients should be available\n", - " on_fit_config_fn=fit_config,\n", - " on_evaluate_config_fn=(lambda r: {\"batch_size\": batch_size}),\n", - " evaluate_fn=functools.partial(\n", - " serverside_eval,\n", - " task_type=task_type,\n", - " testloader=testloader,\n", - " batch_size=batch_size,\n", - " client_tree_num=client_tree_num,\n", - " client_num=client_num,\n", - " ),\n", - " accept_failures=False,\n", - " )\n", - "\n", - " print(\n", - " f\"FL experiment configured for {num_rounds} rounds with {client_pool_size} client in the pool.\"\n", - " )\n", - " print(\n", - " f\"FL round will proceed with {fraction_fit * 100}% of clients sampled, at least {min_fit_clients}.\"\n", - " )\n", - "\n", - " def client_fn(cid: str) -> fl.client.Client:\n", - " \"\"\"Creates a federated learning client\"\"\"\n", - " if val_ratio > 0.0 and val_ratio <= 1.0:\n", - " return FL_Client(\n", - " task_type,\n", - " trainloaders[int(cid)],\n", - " valloaders[int(cid)],\n", - " client_tree_num,\n", - " client_pool_size,\n", - " cid,\n", - " log_progress=False,\n", - " )\n", - " else:\n", - " return FL_Client(\n", - " task_type,\n", - " trainloaders[int(cid)],\n", - " None,\n", - " client_tree_num,\n", - " client_pool_size,\n", - " cid,\n", - " log_progress=False,\n", - " )\n", - "\n", - " # Start the simulation\n", - " history = fl.simulation.start_simulation(\n", - " client_fn=client_fn,\n", - " server=FL_Server(client_manager=SimpleClientManager(), strategy=strategy),\n", - " num_clients=client_pool_size,\n", - " client_resources=client_resources,\n", - " config=ServerConfig(num_rounds=num_rounds),\n", - " strategy=strategy,\n", - " )\n", - "\n", - " print(history)\n", - "\n", - " return history" - ] - 
}, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Start federated training and inference\n", - "#### High-level workflow: \n", - "#### At round 1, each client first builds their own local XGBoost tree, and sends to the server. The server aggregates all trees and sends to all clients. \n", - "#### After round 1, each client calculates every other client tree’s prediction results, and trains a convolutional neural network with 1D convolution kernel size == the number of XGBoost trees in the tree ensemble. \n", - "#### The sharing of privacy-sensitive information is not needed, and the learning rate (a hyperparameter for XGBoost) is learnable using 1D convolution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 624 - }, - "executionInfo": { - "elapsed": 7610, - "status": "error", - "timestamp": 1670363029252, - "user": { - "displayName": "Chenyang Ma", - "userId": "17975430055716133031" - }, - "user_tz": 0 - }, - "outputId": "ee2b7146-07ec-4f97-ba44-5b12b35bbeaf" - }, - "outputs": [], - "source": [ - "start_experiment(\n", - " task_type=task_type,\n", - " trainset=trainset,\n", - " testset=testset,\n", - " num_rounds=20,\n", - " client_tree_num=client_tree_num,\n", - " client_pool_size=client_num,\n", - " num_iterations=100,\n", - " batch_size=64,\n", - " fraction_fit=1.0,\n", - " min_fit_clients=1,\n", - " val_ratio=0.0,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "gpuClass": "premium", - "kernelspec": { - "display_name": "FedXGBoost", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index d9f795766f6d..f5871f1b44e4 100644 --- a/examples/secaggplus-mt/driver.py 
+++ b/examples/secaggplus-mt/driver.py @@ -23,7 +23,8 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id=workload_id, + run_id=run_id, + run_id=run_id, task=merge( task, task_pb2.Task( @@ -84,13 +85,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() +create_run_res: driver_pb2.CreateRunResponse = driver.create_run( + req=driver_pb2.CreateRunRequest() ) # -------------------------------------------------------------------------- Driver SDK -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") +run_id = create_run_res.run_id +print(f"Created run id {run_id}") history = History() for server_round in range(num_rounds): @@ -119,7 +120,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) + get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md index 2fe8366cbc04..11b7a3364376 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using PyTorch -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on either a single machine or a cluster of machines. 
Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive on how Flower simulation works. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -41,7 +41,7 @@ poetry shell Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +poetry run python -c "import flwr" ``` If you don't see any errors you're good to go! @@ -58,7 +58,7 @@ pip install -r requirements.txt ```bash # You can run the example without activating your environemnt -poetry run python3 sim.py +poetry run python sim.py # Or by first activating it poetry shell diff --git a/examples/simulation-pytorch/pyproject.toml b/examples/simulation-pytorch/pyproject.toml index 3b1cacf230f8..07918c0cd17c 100644 --- a/examples/simulation-pytorch/pyproject.toml +++ b/examples/simulation-pytorch/pyproject.toml @@ -11,5 +11,10 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } -torch = "1.13.1" -torchvision = "0.14.1" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +torch = "2.1.1" +torchvision = "0.16.1" + +[tool.poetry.group.dev.dependencies] +ipykernel = "^6.27.0" + diff --git a/examples/simulation-pytorch/requirements.txt b/examples/simulation-pytorch/requirements.txt index 78ac83101d5b..4dbecab3e546 100644 --- a/examples/simulation-pytorch/requirements.txt +++ 
b/examples/simulation-pytorch/requirements.txt @@ -1,3 +1,4 @@ flwr[simulation]>=1.0, <2.0 -torch==1.13.1 -torchvision==0.14.1 \ No newline at end of file +torch==2.1.1 +torchvision==0.16.1 +flwr-datasets[vision]>=0.0.2, <1.0.0 \ No newline at end of file diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index e708aa36542d..508630cf9422 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -21,7 +21,8 @@ "outputs": [], "source": [ "# depending on your shell, you might need to add `\\` before `[` and `]`.\n", - "!pip install -q flwr[simulation]" + "!pip install -q flwr[simulation]\n", + "!pip install flwr_datasets[vision]" ] }, { @@ -29,7 +30,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine]() in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." + "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.dev/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. 
CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -40,22 +41,7 @@ "\n", "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart-_ example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory.\n", "\n", - "In this tutorial we are going to use PyTorch, so let's install a recent version. In this tutorial we'll use a small model so using CPU only training will suffice (this will also prevent Colab from abruptly terminating your experiment if resource limits are exceeded)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "7192138a-8c87-4d9a-f726-af1038ad264c" - }, - "outputs": [], - "source": [ - "# Install Pytorch with CPU support. Please adjust this command for your platform or if you want to use a GPU\n", - "!pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu" + "In this tutorial we are going to use PyTorch, it comes pre-installed in your Collab runtime so there is no need to installed it again. If you wouuld like to install another version, you can still do that in the same way other packages are installed via `!pip`" ] }, { @@ -63,7 +49,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We are going to install some other dependencies you are likely familiar with. We'll use these to make plots." + "We are going to install some other dependencies you are likely familiar with. Let's install `maplotlib` to plot our results at the end." 
] }, { @@ -80,187 +66,6 @@ "!pip install matplotlib" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Centralised training: the old way of doing ML" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's begin by creating a simple (but complete) training loop as it is commonly done in centralised setups. Starting our tutorial in this way will allow us to very clearly identify which parts of a typical ML pipeline are common to both centralised and federated training.\n", - "\n", - "For this tutorial we'll design a image classification pipeline for [MNIST digits](https://en.wikipedia.org/wiki/MNIST_database) and using a simple CNN model as the network to train. The MNIST dataset is comprised of `28x28` greyscale images with digits from 0 to 9 (i.e. 10 classes in total)\n", - "\n", - "\n", - "## A dataset\n", - "\n", - "Let's begin by constructing the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# we naturally first need to import torch and torchvision\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "from torchvision.transforms import ToTensor, Normalize, Compose\n", - "from torchvision.datasets import MNIST\n", - "\n", - "\n", - "def get_mnist(data_path: str = \"./data\"):\n", - " \"\"\"This function downloads the MNIST dataset into the `data_path`\n", - " directory if it is not there already. 
We construct the train/test\n", - " split by converting the images into tensors and normalizing them\"\"\"\n", - "\n", - " # transformation to convert images to tensors and apply normalization\n", - " tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n", - "\n", - " # prepare train and test set\n", - " trainset = MNIST(data_path, train=True, download=True, transform=tr)\n", - " testset = MNIST(data_path, train=False, download=True, transform=tr)\n", - "\n", - " return trainset, testset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's run the code above and do some visualisations to understand better the data we are working with !" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainset, testset = get_mnist()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can have a quick overview of our datasets by just typing the object on the command line. For instance, below you can see that the `trainset` has 60k training examples and will use the transformation rule we defined above in `get_mnist()`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "f10b649f-3cee-4e86-c7ff-94bd1fd3e082" - }, - "outputs": [], - "source": [ - "trainset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a more insightful visualisation. First let's see the distribution over the labels by constructing a histogram. Then, let's visualise some training examples !" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 490 - }, - "outputId": "c8d0f4c0-60cd-4c58-bc91-3b061dae8046" - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "# construct histogram\n", - "all_labels = trainset.targets\n", - "num_possible_labels = len(\n", - " set(all_labels.numpy().tolist())\n", - ") # this counts unique labels (so it should be = 10)\n", - "plt.hist(all_labels, bins=num_possible_labels)\n", - "\n", - "# plot formatting\n", - "plt.xticks(range(num_possible_labels))\n", - "plt.grid()\n", - "plt.xlabel(\"Label\")\n", - "plt.ylabel(\"Number of images\")\n", - "plt.title(\"Class labels distribution for MNIST\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "import numpy as np\n", - "\n", - "\n", - "def visualise_n_random_examples(trainset_, n: int, verbose: bool = True):\n", - " # take n examples at random\n", - " idx = list(range(len(trainset_.data)))\n", - " random.shuffle(idx)\n", - " idx = idx[:n]\n", - " if verbose:\n", - " print(f\"will display images with idx: {idx}\")\n", - "\n", - " # construct canvas\n", - " num_cols = 8\n", - " num_rows = int(np.ceil(len(idx) / num_cols))\n", - " fig, axs = plt.subplots(figsize=(16, num_rows * 2), nrows=num_rows, ncols=num_cols)\n", - "\n", - " # display images on canvas\n", - " for c_i, i in enumerate(idx):\n", - " axs.flat[c_i].imshow(trainset_.data[i], cmap=\"gray\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualise 32 images from the dataset\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 715 - }, - "outputId": "4e0988a8-388d-4acf-882b-089e4ea887bf" - }, - "outputs": [], - "source": [ - "# it is likely that the plot 
this function will generate looks familiar to other plots you might have generated before\n", - "# or you might have encountered in other tutorials. So far, we aren't doing anything new, Federated Learning will start soon!\n", - "visualise_n_random_examples(trainset, n=32)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -278,8 +83,10 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", "\n", "\n", "class Net(nn.Module):\n", @@ -319,28 +126,27 @@ "metadata": {}, "outputs": [], "source": [ - "def train(net, trainloader, optimizer, epochs, device):\n", + "def train(net, trainloader, optim, epochs, device: str):\n", " \"\"\"Train the network on the training set.\"\"\"\n", " criterion = torch.nn.CrossEntropyLoss()\n", " net.train()\n", " for _ in range(epochs):\n", - " for images, labels in trainloader:\n", - " images, labels = images.to(device), labels.to(device)\n", - " optimizer.zero_grad()\n", + " for batch in trainloader:\n", + " images, labels = batch[\"image\"].to(device), batch[\"label\"].to(device)\n", + " optim.zero_grad()\n", " loss = criterion(net(images), labels)\n", " loss.backward()\n", - " optimizer.step()\n", - " return net\n", + " optim.step()\n", "\n", "\n", - "def test(net, testloader, device):\n", + "def test(net, testloader, device: str):\n", " \"\"\"Validate the network on the entire test set.\"\"\"\n", " criterion = torch.nn.CrossEntropyLoss()\n", " correct, loss = 0, 0.0\n", " net.eval()\n", " with torch.no_grad():\n", - " for images, labels in testloader:\n", - " images, labels = images.to(device), labels.to(device)\n", + " for data in testloader:\n", + " images, labels = data[\"image\"].to(device), data[\"label\"].to(device)\n", " outputs = net(images)\n", " loss += criterion(outputs, labels).item()\n", " _, predicted = torch.max(outputs.data, 1)\n", @@ -370,7 +176,9 @@ "source": [ "## One Client, One Data 
Partition\n", "\n", - "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system." + "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n", + "\n", + "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." ] }, { @@ -379,94 +187,42 @@ "metadata": {}, "outputs": [], "source": [ - "from torch.utils.data import random_split\n", - "\n", - "\n", - "def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):\n", - " \"\"\"This function partitions the training set into N disjoint\n", - " subsets, each will become the local dataset of a client. This\n", - " function also subsequently partitions each training set partition\n", - " into train and validation. 
The test set is left intact and will\n", - " be used by the central server to asses the performance of the\n", - " global model.\"\"\"\n", - "\n", - " # get the MNIST datatset\n", - " trainset, testset = get_mnist()\n", - "\n", - " # split trainset into `num_partitions` trainsets\n", - " num_images = len(trainset) // num_partitions\n", - "\n", - " partition_len = [num_images] * num_partitions\n", + "from datasets import Dataset\n", + "from flwr_datasets import FederatedDataset\n", + "from datasets.utils.logging import disable_progress_bar\n", "\n", - " trainsets = random_split(\n", - " trainset, partition_len, torch.Generator().manual_seed(2023)\n", - " )\n", - "\n", - " # create dataloaders with train+val support\n", - " trainloaders = []\n", - " valloaders = []\n", - " for trainset_ in trainsets:\n", - " num_total = len(trainset_)\n", - " num_val = int(val_ratio * num_total)\n", - " num_train = num_total - num_val\n", - "\n", - " for_train, for_val = random_split(\n", - " trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)\n", - " )\n", - "\n", - " trainloaders.append(\n", - " DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)\n", - " )\n", - " valloaders.append(\n", - " DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)\n", - " )\n", - "\n", - " # create dataloader for the test set\n", - " testloader = DataLoader(testset, batch_size=128)\n", + "# Let's set a simulation involving a total of 100 clients\n", + "NUM_CLIENTS = 100\n", "\n", - " return trainloaders, valloaders, testloader" + "# Download MNIST dataset and partition the \"train\" partition (so one can be assigned to each client)\n", + "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", + "# Let's keep the test set as is, and use it to evaluate the global model on the server\n", + "centralized_testset = mnist_fds.load_full(\"test\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, 
"source": [ - "Let's create 100 partitions and extract some statistics from one partition\n" + "Let's create a function that returns a set of transforms to apply to our images" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 508 - }, - "outputId": "0f53ca81-cb55-46ef-c8e0-4e19a4f060b2" - }, + "metadata": {}, "outputs": [], "source": [ - "NUM_CLIENTS = 100\n", - "\n", - "trainloaders, valloaders, testloader = prepare_dataset(\n", - " num_partitions=NUM_CLIENTS, batch_size=32\n", - ")\n", + "from torchvision.transforms import ToTensor, Normalize, Compose\n", "\n", - "# first partition\n", - "train_partition = trainloaders[0].dataset\n", "\n", - "# count data points\n", - "partition_indices = train_partition.indices\n", - "print(f\"number of images: {len(partition_indices)}\")\n", + "def apply_transforms(batch):\n", + " \"\"\"Get transformation for MNIST dataset\"\"\"\n", "\n", - "# visualise histogram\n", - "plt.hist(train_partition.dataset.dataset.targets[partition_indices], bins=10)\n", - "plt.grid()\n", - "plt.xticks(range(10))\n", - "plt.xlabel(\"Label\")\n", - "plt.ylabel(\"Number of images\")\n", - "plt.title(\"Class labels distribution for MNIST\")" + " # transformation to convert images to tensors and apply normalization\n", + " transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n", + " batch[\"image\"] = [transforms(img) for img in batch[\"image\"]]\n", + " return batch" ] }, { @@ -474,9 +230,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As you can see, the histogram of this partition is a bit different from the one we obtained at the beginning where we took the entire dataset into consideration. 
Because our data partitions are artificially constructed by sampling the MNIST dataset in an IID fashion, our Federated Learning example will not face sever _data heterogeneity_ issues (which is a fairly [active research topic](https://arxiv.org/abs/1912.04977)).\n", - "\n", - "Let's next define how our FL clients will behave\n", + "Let's next define how our FL clients will behave.\n", "\n", "## Defining a Flower Client\n", "\n", @@ -521,16 +275,15 @@ "from collections import OrderedDict\n", "from typing import Dict, List, Tuple\n", "\n", - "import torch\n", "from flwr.common import NDArrays, Scalar\n", "\n", "\n", "class FlowerClient(fl.client.NumPyClient):\n", - " def __init__(self, trainloader, vallodaer) -> None:\n", + " def __init__(self, trainloader, valloader) -> None:\n", " super().__init__()\n", "\n", " self.trainloader = trainloader\n", - " self.valloader = vallodaer\n", + " self.valloader = valloader\n", " self.model = Net(num_classes=10)\n", " # Determine device\n", " self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", @@ -613,7 +366,7 @@ "metadata": {}, "outputs": [], "source": [ - "def get_evaluate_fn(testloader):\n", + "def get_evaluate_fn(centralized_testset: Dataset):\n", " \"\"\"This is a function that returns a function. The returned\n", " function (i.e. 
`evaluate_fn`) will be executed by the strategy\n", " at the end of each round to evaluate the stat of the global\n", @@ -636,20 +389,15 @@ " state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n", " model.load_state_dict(state_dict, strict=True)\n", "\n", + " # Apply transform to dataset\n", + " testset = centralized_testset.with_transform(apply_transforms)\n", + "\n", + " testloader = DataLoader(testset, batch_size=50)\n", " # call test\n", " loss, accuracy = test(model, testloader, device)\n", " return loss, {\"accuracy\": accuracy}\n", "\n", - " return evaluate_fn\n", - "\n", - "\n", - "# now we can define the strategy\n", - "# strategy = fl.server.strategy.FedAvg(\n", - "# fraction_fit=0.1,\n", - "# fraction_evaluate=0.1,\n", - "# min_available_clients=100,\n", - "# evaluate_fn=get_evaluate_fn(testloader), # Even this is not required\n", - "# )" + " return evaluate_fn" ] }, { @@ -707,14 +455,9 @@ "strategy = fl.server.strategy.FedAvg(\n", " fraction_fit=0.1, # Sample 10% of available clients for training\n", " fraction_evaluate=0.05, # Sample 5% of available clients for evaluation\n", - " min_fit_clients=10, # Never sample less than 10 clients for training\n", - " min_evaluate_clients=5, # Never sample less than 5 clients for evaluation\n", - " min_available_clients=int(\n", - " NUM_CLIENTS * 0.75\n", - " ), # Wait until at least 75 clients are available\n", " on_fit_config_fn=fit_config,\n", " evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics\n", - " evaluate_fn=get_evaluate_fn(testloader), # global evaluation function\n", + " evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function\n", ")" ] }, @@ -737,18 +480,41 @@ "metadata": {}, "outputs": [], "source": [ - "def generate_client_fn(trainloaders, valloaders):\n", - " def client_fn(cid: str):\n", - " \"\"\"Returns a FlowerClient containing the cid-th data partition\"\"\"\n", + "from torch.utils.data import DataLoader\n", + "\n", + "\n", 
+ "def get_client_fn(dataset: FederatedDataset):\n", + " \"\"\"Return a function to construct a client.\n", + "\n", + " The VirtualClientEngine will execute this function whenever a client is sampled by\n", + " the strategy to participate.\n", + " \"\"\"\n", "\n", - " return FlowerClient(\n", - " trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n", + " def client_fn(cid: str) -> fl.client.Client:\n", + " \"\"\"Construct a FlowerClient with its own dataset partition.\"\"\"\n", + "\n", + " # Let's get the partition corresponding to the i-th client\n", + " client_dataset = dataset.load_partition(int(cid), \"train\")\n", + "\n", + " # Now let's split it into train (90%) and validation (10%)\n", + " client_dataset_splits = client_dataset.train_test_split(test_size=0.1)\n", + "\n", + " trainset = client_dataset_splits[\"train\"]\n", + " valset = client_dataset_splits[\"test\"]\n", + "\n", + " # Now we apply the transform to each batch.\n", + " trainloader = DataLoader(\n", + " trainset.with_transform(apply_transforms), batch_size=32, shuffle=True\n", " )\n", + " valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n", + "\n", + " # Create and return client\n", + " return FlowerClient(trainloader, valloader)\n", "\n", " return client_fn\n", "\n", "\n", - "client_fn_callback = generate_client_fn(trainloaders, valloaders)" + "client_fn_callback = get_client_fn(mnist_fds)" ] }, { @@ -774,6 +540,8 @@ "# client needs exclusive access to these many resources in order to run\n", "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "\n", + "# Let's disable tqdm progress bar in the main thread (used by the server)\n", + "disable_progress_bar()\n", "\n", "history = fl.simulation.start_simulation(\n", " client_fn=client_fn_callback, # a callback to construct a client\n", @@ -781,6 +549,9 @@ " config=fl.server.ServerConfig(num_rounds=10), # let's run for 10 rounds\n", " strategy=strategy, # the strategy that will orchestrate the 
whole FL pipeline\n", " client_resources=client_resources,\n", + " actor_kwargs={\n", + " \"on_actor_init_fn\": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients\n", + " },\n", ")" ] }, @@ -806,6 +577,8 @@ }, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "\n", "print(f\"{history.metrics_centralized = }\")\n", "\n", "global_accuracy_centralised = history.metrics_centralized[\"accuracy\"]\n", @@ -817,6 +590,27 @@ "plt.xlabel(\"Round\")\n", "plt.title(\"MNIST - IID - 100 clients with 10 clients per round\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Congratulations! With that, you built a Flower client, customized its instantiation through the `client_fn`, customized the server-side execution through a `FedAvg` strategy configured for this workload, and started a simulation with 100 clients (each holding their own individual partition of the MNIST dataset).\n", + "\n", + "Next, you can continue to explore more advanced Flower topics:\n", + "\n", + "- Deploy server and clients on different machines using `start_server` and `start_client`\n", + "- Customize the server-side execution through custom strategies\n", + "- Customize the client-side execution through `config` dictionaries\n", + "\n", + "Get all resources you need!\n", + "\n", + "* **[DOCS]** Our complete documentation: https://flower.dev/docs/\n", + "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs\n", + "\n", + "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + ] } ], "metadata": { @@ -825,10 +619,11 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/examples/simulation-pytorch/sim.py 
b/examples/simulation-pytorch/sim.py index 5adfca744591..68d9426e83ab 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -3,15 +3,17 @@ from typing import Dict, Tuple, List import torch -import torchvision -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader import flwr as fl from flwr.common import Metrics from flwr.common.typing import Scalar -from utils import Net, train, test, get_mnist +from datasets import Dataset +from datasets.utils.logging import disable_progress_bar +from flwr_datasets import FederatedDataset +from utils import Net, train, test, apply_transforms parser = argparse.ArgumentParser(description="Flower Simulation with PyTorch") @@ -78,18 +80,28 @@ def evaluate(self, parameters, config): return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} -def get_client_fn(train_partitions, val_partitions): +def get_client_fn(dataset: FederatedDataset): """Return a function to construct a client. - The VirtualClientEngine will exectue this function whenever a client is sampled by + The VirtualClientEngine will execute this function whenever a client is sampled by the strategy to participate. """ def client_fn(cid: str) -> fl.client.Client: """Construct a FlowerClient with its own dataset partition.""" - # Extract partition for client with id = cid - trainset, valset = train_partitions[int(cid)], val_partitions[int(cid)] + # Let's get the partition corresponding to the i-th client + client_dataset = dataset.load_partition(int(cid), "train") + + # Now let's split it into train (90%) and validation (10%) + client_dataset_splits = client_dataset.train_test_split(test_size=0.1) + + trainset = client_dataset_splits["train"] + valset = client_dataset_splits["test"] + + # Now we apply the transform to each batch. 
+ trainset = trainset.with_transform(apply_transforms) + valset = valset.with_transform(apply_transforms) # Create and return client return FlowerClient(trainset, valset) @@ -113,40 +125,6 @@ def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): model.load_state_dict(state_dict, strict=True) -def prepare_dataset(): - """Download and partitions the MNIST dataset.""" - - # Get the MNIST dataset - trainset, testset = get_mnist() - - # Split trainset into `num_partitions` trainsets - num_images = len(trainset) // NUM_CLIENTS - partition_len = [num_images] * NUM_CLIENTS - - trainsets = random_split( - trainset, partition_len, torch.Generator().manual_seed(2023) - ) - - val_ratio = 0.1 - - # Create dataloaders with train+val support - train_partitions = [] - val_partitions = [] - for trainset_ in trainsets: - num_total = len(trainset_) - num_val = int(val_ratio * num_total) - num_train = num_total - num_val - - for_train, for_val = random_split( - trainset_, [num_train, num_val], torch.Generator().manual_seed(2023) - ) - - train_partitions.append(for_train) - val_partitions.append(for_val) - - return train_partitions, val_partitions, testset - - def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: """Aggregation function for (federated) evaluation metrics, i.e. 
those returned by the client's evaluate() method.""" @@ -159,7 +137,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: def get_evaluate_fn( - testset: torchvision.datasets.CIFAR10, + centralized_testset: Dataset, ): """Return an evaluation function for centralized evaluation.""" @@ -175,6 +153,12 @@ def evaluate( set_params(model, parameters) model.to(device) + # Apply transform to dataset + testset = centralized_testset.with_transform(apply_transforms) + + # Disable tqdm for dataset preprocessing + disable_progress_bar() + testloader = DataLoader(testset, batch_size=50) loss, accuracy = test(model, testloader, device=device) @@ -187,8 +171,9 @@ def main(): # Parse input arguments args = parser.parse_args() - # Download CIFAR-10 dataset and partition it - trainsets, valsets, testset = prepare_dataset() + # Download MNIST dataset and partition it + mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + centralized_testset = mnist_fds.load_full("test") # Configure the strategy strategy = fl.server.strategy.FedAvg( @@ -201,7 +186,7 @@ def main(): ), # Wait until at least 75 clients are available on_fit_config_fn=fit_config, evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics - evaluate_fn=get_evaluate_fn(testset), # Global evaluation function + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function ) # Resources to be assigned to each virtual client @@ -212,11 +197,14 @@ def main(): # Start simulation fl.simulation.start_simulation( - client_fn=get_client_fn(trainsets, valsets), + client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, client_resources=client_resources, config=fl.server.ServerConfig(num_rounds=args.num_rounds), strategy=strategy, + actor_kwargs={ + "on_actor_init_fn": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients + }, ) diff --git a/examples/simulation-pytorch/utils.py b/examples/simulation-pytorch/utils.py 
index fff6bb490930..01f63cc94ba3 100644 --- a/examples/simulation-pytorch/utils.py +++ b/examples/simulation-pytorch/utils.py @@ -3,7 +3,13 @@ import torch.nn.functional as F from torchvision.transforms import ToTensor, Normalize, Compose -from torchvision.datasets import MNIST + + +# transformation to convert images to tensors and apply normalization +def apply_transforms(batch): + transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch # Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz') @@ -33,8 +39,8 @@ def train(net, trainloader, optim, epochs, device: str): criterion = torch.nn.CrossEntropyLoss() net.train() for _ in range(epochs): - for images, labels in trainloader: - images, labels = images.to(device), labels.to(device) + for batch in trainloader: + images, labels = batch["image"].to(device), batch["label"].to(device) optim.zero_grad() loss = criterion(net(images), labels) loss.backward() @@ -49,23 +55,10 @@ def test(net, testloader, device: str): net.eval() with torch.no_grad(): for data in testloader: - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["image"].to(device), data["label"].to(device) outputs = net(images) loss += criterion(outputs, labels).item() _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() accuracy = correct / len(testloader.dataset) return loss, accuracy - - -def get_mnist(data_path: str = "./data"): - """Download MNIST and apply transform.""" - - # transformation to convert images to tensors and apply normalization - tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) - - # prepare train and test set - trainset = MNIST(data_path, train=True, download=True, transform=tr) - testset = MNIST(data_path, train=False, download=True, transform=tr) - - return trainset, testset diff --git a/examples/simulation-tensorflow/README.md 
b/examples/simulation-tensorflow/README.md index 61a6749a6bdf..f0d94f343d37 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using TensorFlow/Keras -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on either a single machine of a cluster of machines. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive on how Flower simulation works. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This example uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -40,7 +40,7 @@ poetry shell Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +poetry run python -c "import flwr" ``` If you don't see any errors you're good to go! 
@@ -57,7 +57,7 @@ pip install -r requirements.txt ```bash # You can run the example without activating your environemnt -poetry run python3 sim.py +poetry run python sim.py # Or by first activating it poetry shell diff --git a/examples/simulation-tensorflow/pyproject.toml b/examples/simulation-tensorflow/pyproject.toml index 4016c3da0da0..f2e7bd3006c0 100644 --- a/examples/simulation-tensorflow/pyproject.toml +++ b/examples/simulation-tensorflow/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"} tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"} diff --git a/examples/simulation-tensorflow/requirements.txt b/examples/simulation-tensorflow/requirements.txt index 76e77f5ff9b8..bb69a87be1b4 100644 --- a/examples/simulation-tensorflow/requirements.txt +++ b/examples/simulation-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr[simulation]>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 559dcf3170a3..575b437018f3 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -17,7 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q flwr[\"simulation\"] tensorflow" + "!pip install -q flwr[\"simulation\"] tensorflow\n", + "!pip install -q flwr_datasets[\"vision\"]" ] }, { @@ -49,7 +50,6 @@ "metadata": {}, "outputs": [], "source": [ - "import math\n", "from typing import Dict, List, Tuple\n", "\n", "import tensorflow as 
tf\n", @@ -58,6 +58,9 @@ "from flwr.common import Metrics\n", "from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth\n", "\n", + "from datasets import Dataset\n", + "from flwr_datasets import FederatedDataset\n", + "\n", "VERBOSE = 0\n", "NUM_CLIENTS = 100" ] @@ -113,28 +116,24 @@ "outputs": [], "source": [ "class FlowerClient(fl.client.NumPyClient):\n", - " def __init__(self, x_train, y_train, x_val, y_val) -> None:\n", + " def __init__(self, trainset, valset) -> None:\n", " # Create model\n", " self.model = get_model()\n", - " self.x_train, self.y_train = x_train, y_train\n", - " self.x_val, self.y_val = x_val, y_val\n", + " self.trainset = trainset\n", + " self.valset = valset\n", "\n", " def get_parameters(self, config):\n", " return self.model.get_weights()\n", "\n", " def fit(self, parameters, config):\n", " self.model.set_weights(parameters)\n", - " self.model.fit(\n", - " self.x_train, self.y_train, epochs=1, batch_size=32, verbose=VERBOSE\n", - " )\n", - " return self.model.get_weights(), len(self.x_train), {}\n", + " self.model.fit(self.trainset, epochs=1, verbose=VERBOSE)\n", + " return self.model.get_weights(), len(self.trainset), {}\n", "\n", " def evaluate(self, parameters, config):\n", " self.model.set_weights(parameters)\n", - " loss, acc = self.model.evaluate(\n", - " self.x_val, self.y_val, batch_size=64, verbose=VERBOSE\n", - " )\n", - " return loss, len(self.x_val), {\"accuracy\": acc}" + " loss, acc = self.model.evaluate(self.valset, verbose=VERBOSE)\n", + " return loss, len(self.valset), {\"accuracy\": acc}" ] }, { @@ -155,8 +154,6 @@ "We now define four auxiliary functions for this example (note the last two are entirely optional):\n", "* `get_client_fn()`: Is a function that returns another function. The returned `client_fn` will be executed by Flower's VirtualClientEngine each time a new _virtual_ client (i.e. a client that is simulated in a Python process) needs to be spawn. When are virtual clients spawned? 
Each time the strategy samples them to do either `fit()` (i.e. train the global model on the local data of a particular client) or `evaluate()` (i.e. evaluate the global model on the validation set of a given client).\n", "\n", - "* `partition_mnist()`: A utility function that downloads the MNIST dataset and partitions it into `NUM_CLIENT` disjoint sets. The resulting list of dataset partitions will be passed to `get_client_fn()` so a client can be constructed by passing it its corresponding dataset partition. There are multiple ways of partitioning a dataset, but in this example we keep things simple. For larger dataset, you might want to pre-partition your dataset before running your Flower experiment and, potentially, store these partition into your files system or a database. In this way, your `FlowerClient` objects can retrieve their data directly when doing either `fit()` or `evaluate()`.\n", - "\n", "* `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when client run `evaluate()`) and will aggregate the metrics clients return. In this example, we use this function to compute the weighted average accuracy of clients doing `evaluate()`.\n", "\n", "* `get_evaluate_fn()`: This is again a function that returns another function. The returned function will be executed by the strategy at the end of a `fit()` round and after a new global model has been obtained after aggregation. This is an optional argument for Flower strategies. In this example, we use the whole MNIST test set to perform this server-side evaluation." 
@@ -168,42 +165,35 @@ "metadata": {}, "outputs": [], "source": [ - "def get_client_fn(dataset_partitions):\n", - " \"\"\"Return a function to be executed by the VirtualClientEngine in order to construct\n", - " a client.\"\"\"\n", + "def get_client_fn(dataset: FederatedDataset):\n", + " \"\"\"Return a function to construct a client.\n", + "\n", + " The VirtualClientEngine will execute this function whenever a client is sampled by\n", + " the strategy to participate.\n", + " \"\"\"\n", "\n", " def client_fn(cid: str) -> fl.client.Client:\n", " \"\"\"Construct a FlowerClient with its own dataset partition.\"\"\"\n", "\n", " # Extract partition for client with id = cid\n", - " x_train, y_train = dataset_partitions[int(cid)]\n", - " # Use 10% of the client's training data for validation\n", - " split_idx = math.floor(len(x_train) * 0.9)\n", - " x_train_cid, y_train_cid = (\n", - " x_train[:split_idx],\n", - " y_train[:split_idx],\n", + " client_dataset = dataset.load_partition(int(cid), \"train\")\n", + "\n", + " # Now let's split it into train (90%) and validation (10%)\n", + " client_dataset_splits = client_dataset.train_test_split(test_size=0.1)\n", + "\n", + " trainset = client_dataset_splits[\"train\"].to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=32\n", + " )\n", + " valset = client_dataset_splits[\"test\"].to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=64\n", " )\n", - " x_val_cid, y_val_cid = x_train[split_idx:], y_train[split_idx:]\n", "\n", " # Create and return client\n", - " return FlowerClient(x_train_cid, y_train_cid, x_val_cid, y_val_cid)\n", + " return FlowerClient(trainset, valset)\n", "\n", " return client_fn\n", "\n", "\n", - "def partition_mnist():\n", - " \"\"\"Download and partitions the MNIST dataset.\"\"\"\n", - " (x_train, y_train), testset = tf.keras.datasets.mnist.load_data()\n", - " partitions = []\n", - " # We keep all partitions equal-sized in this example\n", - " partition_size = 
math.floor(len(x_train) / NUM_CLIENTS)\n", - " for cid in range(NUM_CLIENTS):\n", - " # Split dataset into non-overlapping NUM_CLIENT partitions\n", - " idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size\n", - " partitions.append((x_train[idx_from:idx_to] / 255.0, y_train[idx_from:idx_to]))\n", - " return partitions, testset\n", - "\n", - "\n", "def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n", " \"\"\"Aggregation function for (federated) evaluation metrics, i.e. those returned by\n", " the client's evaluate() method.\"\"\"\n", @@ -215,9 +205,8 @@ " return {\"accuracy\": sum(accuracies) / sum(examples)}\n", "\n", "\n", - "def get_evaluate_fn(testset):\n", - " \"\"\"Return an evaluation function for server-side (i.e. centralized) evaluation.\"\"\"\n", - " x_test, y_test = testset\n", + "def get_evaluate_fn(testset: Dataset):\n", + " \"\"\"Return an evaluation function for server-side (i.e. centralised) evaluation.\"\"\"\n", "\n", " # The `evaluate` function will be called after every round by the strategy\n", " def evaluate(\n", @@ -227,7 +216,7 @@ " ):\n", " model = get_model() # Construct the model\n", " model.set_weights(parameters) # Update model with the latest parameters\n", - " loss, accuracy = model.evaluate(x_test, y_test, verbose=VERBOSE)\n", + " loss, accuracy = model.evaluate(testset, verbose=VERBOSE)\n", " return loss, {\"accuracy\": accuracy}\n", "\n", " return evaluate" @@ -241,7 +230,9 @@ "\n", "The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. 
The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).\n", "\n", - "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation." + "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.\n", + "\n", + "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." 
] }, { @@ -253,8 +244,13 @@ "# Enable GPU growth in your main process\n", "enable_tf_gpu_growth()\n", "\n", - "# Create dataset partitions (needed if your dataset is not pre-partitioned)\n", - "partitions, testset = partition_mnist()\n", + "# Download MNIST dataset and partition it\n", + "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", + "# Get the whole test set for centralised evaluation\n", + "centralized_testset = mnist_fds.load_full(\"test\").to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=64\n", + ")\n", + "\n", "\n", "# Create FedAvg strategy\n", "strategy = fl.server.strategy.FedAvg(\n", @@ -266,7 +262,7 @@ " NUM_CLIENTS * 0.75\n", " ), # Wait until at least 75 clients are available\n", " evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics\n", - " evaluate_fn=get_evaluate_fn(testset), # global evaluation function\n", + " evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function\n", ")\n", "\n", "# With a dictionary, you tell Flower's VirtualClientEngine that each\n", @@ -275,7 +271,7 @@ "\n", "# Start simulation\n", "history = fl.simulation.start_simulation(\n", - " client_fn=get_client_fn(partitions),\n", + " client_fn=get_client_fn(mnist_fds),\n", " num_clients=NUM_CLIENTS,\n", " config=fl.server.ServerConfig(num_rounds=10),\n", " strategy=strategy,\n", @@ -290,7 +286,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can then use the resturned History object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtainined at the end of each round (including at the very beginning of the experiment) for the global model. This is want the function evaluate_fn() that we passed to the strategy reports." 
+ "You can then use the returned History object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtained at the end of each round (including at the very beginning of the experiment) for the global model. This is what the function `evaluate_fn()` that we passed to the strategy reports." ] }, { @@ -323,7 +319,15 @@ "\n", "- Deploy server and clients on different machines using `start_server` and `start_client`\n", "- Customize the server-side execution through custom strategies\n", - "- Customize the client-side execution through `config` dictionaries" + "- Customize the client-side execution through `config` dictionaries\n", + "\n", + "Get all resources you need!\n", + "\n", + "* **[DOCS]** Our complete documentation: https://flower.dev/docs/\n", + "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs\n", + "\n", + "Don't forget to join our Slack channel: https://flower.dev/join-slack/" ] } ], @@ -333,11 +337,11 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.8.12 ('.venv': poetry)", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/examples/simulation-tensorflow/sim.py b/examples/simulation-tensorflow/sim.py index 15f7097dc439..490e25fe8c8d 100644 --- a/examples/simulation-tensorflow/sim.py +++ b/examples/simulation-tensorflow/sim.py @@ -9,6 +9,8 @@ from flwr.common import Metrics from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth +from datasets import Dataset +from flwr_datasets import FederatedDataset # Make TensorFlow logs less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -34,28 +36,24 @@ class FlowerClient(fl.client.NumPyClient): - def __init__(self, x_train, y_train, x_val, y_val) -> None: 
+ def __init__(self, trainset, valset) -> None: # Create model self.model = get_model() - self.x_train, self.y_train = x_train, y_train - self.x_val, self.y_val = x_val, y_val + self.trainset = trainset + self.valset = valset def get_parameters(self, config): return self.model.get_weights() def fit(self, parameters, config): self.model.set_weights(parameters) - self.model.fit( - self.x_train, self.y_train, epochs=1, batch_size=32, verbose=VERBOSE - ) - return self.model.get_weights(), len(self.x_train), {} + self.model.fit(self.trainset, epochs=1, verbose=VERBOSE) + return self.model.get_weights(), len(self.trainset), {} def evaluate(self, parameters, config): self.model.set_weights(parameters) - loss, acc = self.model.evaluate( - self.x_val, self.y_val, batch_size=64, verbose=VERBOSE - ) - return loss, len(self.x_val), {"accuracy": acc} + loss, acc = self.model.evaluate(self.valset, verbose=VERBOSE) + return loss, len(self.valset), {"accuracy": acc} def get_model(): @@ -72,10 +70,10 @@ def get_model(): return model -def get_client_fn(dataset_partitions): - """Return a function to construc a client. +def get_client_fn(dataset: FederatedDataset): + """Return a function to construct a client. - The VirtualClientEngine will exectue this function whenever a client is sampled by + The VirtualClientEngine will execute this function whenever a client is sampled by the strategy to participate. 
""" @@ -83,34 +81,24 @@ def client_fn(cid: str) -> fl.client.Client: """Construct a FlowerClient with its own dataset partition.""" # Extract partition for client with id = cid - x_train, y_train = dataset_partitions[int(cid)] - # Use 10% of the client's training data for validation - split_idx = math.floor(len(x_train) * 0.9) - x_train_cid, y_train_cid = ( - x_train[:split_idx], - y_train[:split_idx], + client_dataset = dataset.load_partition(int(cid), "train") + + # Now let's split it into train (90%) and validation (10%) + client_dataset_splits = client_dataset.train_test_split(test_size=0.1) + + trainset = client_dataset_splits["train"].to_tf_dataset( + columns="image", label_cols="label", batch_size=32 + ) + valset = client_dataset_splits["test"].to_tf_dataset( + columns="image", label_cols="label", batch_size=64 ) - x_val_cid, y_val_cid = x_train[split_idx:], y_train[split_idx:] # Create and return client - return FlowerClient(x_train_cid, y_train_cid, x_val_cid, y_val_cid) + return FlowerClient(trainset, valset) return client_fn -def partition_mnist(): - """Download and partitions the MNIST dataset.""" - (x_train, y_train), testset = tf.keras.datasets.mnist.load_data() - partitions = [] - # We keep all partitions equal-sized in this example - partition_size = math.floor(len(x_train) / NUM_CLIENTS) - for cid in range(NUM_CLIENTS): - # Split dataset into non-overlapping NUM_CLIENT partitions - idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size - partitions.append((x_train[idx_from:idx_to] / 255.0, y_train[idx_from:idx_to])) - return partitions, testset - - def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: """Aggregation function for (federated) evaluation metrics. 
@@ -124,9 +112,8 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: return {"accuracy": sum(accuracies) / sum(examples)} -def get_evaluate_fn(testset): +def get_evaluate_fn(testset: Dataset): """Return an evaluation function for server-side (i.e. centralised) evaluation.""" - x_test, y_test = testset # The `evaluate` function will be called after every round by the strategy def evaluate( @@ -136,7 +123,7 @@ def evaluate( ): model = get_model() # Construct the model model.set_weights(parameters) # Update model with the latest parameters - loss, accuracy = model.evaluate(x_test, y_test, verbose=VERBOSE) + loss, accuracy = model.evaluate(testset, verbose=VERBOSE) return loss, {"accuracy": accuracy} return evaluate @@ -146,8 +133,12 @@ def main() -> None: # Parse input arguments args = parser.parse_args() - # Create dataset partitions (needed if your dataset is not pre-partitioned) - partitions, testset = partition_mnist() + # Download MNIST dataset and partition it + mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + # Get the whole test set for centralised evaluation + centralized_testset = mnist_fds.load_full("test").to_tf_dataset( + columns="image", label_cols="label", batch_size=64 + ) # Create FedAvg strategy strategy = fl.server.strategy.FedAvg( @@ -159,7 +150,7 @@ def main() -> None: NUM_CLIENTS * 0.75 ), # Wait until at least 75 clients are available evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics - evaluate_fn=get_evaluate_fn(testset), # global evaluation function + evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function ) # With a dictionary, you tell Flower's VirtualClientEngine that each @@ -171,7 +162,7 @@ def main() -> None: # Start simulation fl.simulation.start_simulation( - client_fn=get_client_fn(partitions), + client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, config=fl.server.ServerConfig(num_rounds=args.num_rounds), 
strategy=strategy, diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index 79ed63a64233..ee3cdfc9768e 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,7 +1,7 @@ # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. -Running this example in itself is quite easy. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. ## Project Setup @@ -57,18 +57,24 @@ Afterwards you are ready to start the Flower server as well as the clients. You poetry run python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two or more terminals and run the following command in each: + +Start client 1 in the first terminal: ```shell -poetry run python3 client.py +python3 client.py --node-id 0 # or any integer in {0-9} ``` -Alternatively you can run all of it in one shell as follows: +Start client 2 in the second terminal: ```shell -poetry run python3 server.py & -poetry run python3 client.py & -poetry run python3 client.py +python3 client.py --node-id 1 # or any integer in {0-9} +``` + +Alternatively, you can run all of it in one shell as follows: + +```bash +bash run.sh ``` You will see that Flower is starting a federated training. 
diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index dbf0f2f462a7..a5fcaba87409 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -1,19 +1,35 @@ +import argparse import warnings -import flwr as fl -import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss +import flwr as fl import utils +from flwr_datasets import FederatedDataset if __name__ == "__main__": - # Load MNIST dataset from https://www.openml.org/d/554 - (X_train, y_train), (X_test, y_test) = utils.load_mnist() + N_CLIENTS = 10 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.node_id + + # Load the partition data + fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS}) - # Split train set into 10 partitions and randomly use one for training. 
- partition_id = np.random.choice(10) - (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id] + dataset = fds.load_partition(partition_id, "train").with_format("numpy") + X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] # Create LogisticRegression Model model = LogisticRegression( diff --git a/examples/sklearn-logreg-mnist/pyproject.toml b/examples/sklearn-logreg-mnist/pyproject.toml index 7c13b3f3d492..8ea49fe187a2 100644 --- a/examples/sklearn-logreg-mnist/pyproject.toml +++ b/examples/sklearn-logreg-mnist/pyproject.toml @@ -13,7 +13,7 @@ authors = [ [tool.poetry.dependencies] python = "^3.8" -flwr = "^1.0.0" +flwr = ">=1.0,<2.0" # flwr = { path = "../../", develop = true } # Development +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } scikit-learn = "^1.1.1" -openml = "^0.12.2" diff --git a/examples/sklearn-logreg-mnist/requirements.txt b/examples/sklearn-logreg-mnist/requirements.txt index eec2e1a3c4bd..50da9ace3630 100644 --- a/examples/sklearn-logreg-mnist/requirements.txt +++ b/examples/sklearn-logreg-mnist/requirements.txt @@ -1,4 +1,4 @@ -flwr~=1.4.0 +flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 numpy~=1.21.1 -openml~=0.13.1 scikit_learn~=1.2.2 diff --git a/examples/sklearn-logreg-mnist/run.sh b/examples/sklearn-logreg-mnist/run.sh index c64f362086aa..48cee1b41b74 100755 --- a/examples/sklearn-logreg-mnist/run.sh +++ b/examples/sklearn-logreg-mnist/run.sh @@ -1,15 +1,17 @@ #!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --node-id "${i}" & done # This 
will allow you to use CTRL+C to stop all background processes -trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM # Wait for all background processes to complete wait diff --git a/examples/sklearn-logreg-mnist/server.py b/examples/sklearn-logreg-mnist/server.py index 77e7a89dd668..8541100c3a26 100644 --- a/examples/sklearn-logreg-mnist/server.py +++ b/examples/sklearn-logreg-mnist/server.py @@ -4,6 +4,8 @@ from sklearn.linear_model import LogisticRegression from typing import Dict +from flwr_datasets import FederatedDataset + def fit_round(server_round: int) -> Dict: """Send round number to client.""" @@ -14,7 +16,9 @@ def get_evaluate_fn(model: LogisticRegression): """Return an evaluation function for server-side evaluation.""" # Load test data here to avoid the overhead of doing it in `evaluate` itself - _, (X_test, y_test) = utils.load_mnist() + fds = FederatedDataset(dataset="mnist", partitioners={"train": 10}) + dataset = fds.load_full("test").with_format("numpy") + X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"] # The `evaluate` function will be called after every round def evaluate(server_round, parameters: fl.common.NDArrays, config): diff --git a/examples/sklearn-logreg-mnist/utils.py b/examples/sklearn-logreg-mnist/utils.py index 6a6d6c12ac73..b279a0d1a4b3 100644 --- a/examples/sklearn-logreg-mnist/utils.py +++ b/examples/sklearn-logreg-mnist/utils.py @@ -1,16 +1,11 @@ -from typing import Tuple, Union, List import numpy as np from sklearn.linear_model import LogisticRegression -import openml -XY = Tuple[np.ndarray, np.ndarray] -Dataset = Tuple[XY, XY] -LogRegParams = Union[XY, Tuple[np.ndarray]] -XYList = List[XY] +from flwr.common import NDArrays -def get_model_parameters(model: LogisticRegression) -> LogRegParams: - """Returns the paramters of a sklearn LogisticRegression model.""" +def get_model_parameters(model: LogisticRegression) -> NDArrays: + """Returns the 
parameters of a sklearn LogisticRegression model.""" if model.fit_intercept: params = [ model.coef_, @@ -23,9 +18,7 @@ def get_model_parameters(model: LogisticRegression) -> LogRegParams: return params -def set_model_params( - model: LogisticRegression, params: LogRegParams -) -> LogisticRegression: +def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression: """Sets the parameters of a sklean LogisticRegression model.""" model.coef_ = params[0] if model.fit_intercept: @@ -47,32 +40,3 @@ def set_initial_params(model: LogisticRegression): model.coef_ = np.zeros((n_classes, n_features)) if model.fit_intercept: model.intercept_ = np.zeros((n_classes,)) - - -def load_mnist() -> Dataset: - """Loads the MNIST dataset using OpenML. - - OpenML dataset link: https://www.openml.org/d/554 - """ - mnist_openml = openml.datasets.get_dataset(554) - Xy, _, _, _ = mnist_openml.get_data(dataset_format="array") - X = Xy[:, :-1] # the last column contains labels - y = Xy[:, -1] - # First 60000 samples consist of the train set - x_train, y_train = X[:60000], y[:60000] - x_test, y_test = X[60000:], y[60000:] - return (x_train, y_train), (x_test, y_test) - - -def shuffle(X: np.ndarray, y: np.ndarray) -> XY: - """Shuffle X and y.""" - rng = np.random.default_rng() - idx = rng.permutation(len(X)) - return X[idx], y[idx] - - -def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList: - """Split X and y into a number of partitions.""" - return list( - zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions)) - ) diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index da002a10d301..11c4c3f9a08b 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -9,6 +9,7 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai - Customised number of partitions. 
- Customised partitioner type (uniform, linear, square, exponential). - Centralised/distributed evaluation. +- Bagging/cyclic training methods. ## Project Setup @@ -26,7 +27,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f -- client.py <- Defines the client-side logic -- dataset.py <- Defines the functions of data loading and partitioning -- utils.py <- Defines the arguments parser for clients and server --- run.sh <- Commands to run experiments +-- run_bagging.sh <- Commands to run bagging experiments +-- run_cyclic.sh <- Commands to run cyclic experiments -- pyproject.toml <- Example dependencies (if you use Poetry) -- requirements.txt <- Example dependencies ``` @@ -60,24 +62,31 @@ pip install -r requirements.txt ## Run Federated Learning with XGBoost and Flower -The included `run.sh` will start the Flower server (using `server.py`) with centralised evaluation, +We have two scripts to run bagging and cyclic (client-by-client) experiments. +The included `run_bagging.sh` or `run_cyclic.sh` will start the Flower server (using `server.py`), sleep for 15 seconds to ensure that the server is up, and then start 5 Flower clients (using `client.py`) with a small subset of the data from exponential partition distribution. You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +poetry run ./run_bagging.sh ``` -The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. +Or + +```shell +poetry run ./run_cyclic.sh +``` + +The script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. 
This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -You can also manually run `poetry run python3 server.py --pool-size=N --num-clients-per-round=N` -and `poetry run python3 client.py --node-id=NODE_ID --num-partitions=N` for as many clients as you want, +You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N` +and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want, but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). @@ -86,6 +95,8 @@ and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.htm ### Expected Experimental Results +#### Bagging aggregation experiment + ![](_static/xgboost_flower_auc.png) The figure above shows the centralised tested AUC performance over FL rounds on 4 experimental settings. 
diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index a37edac32648..ff7a4adf7977 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -101,11 +101,16 @@ def _local_boost(self): for i in range(num_local_round): self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) - # Extract the last N=num_local_round trees for sever aggregation - bst = self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() - ] + # Bagging: extract the last N=num_local_round trees for sever aggregation + # Cyclic: return the entire model + bst = ( + self.bst[ + self.bst.num_boosted_rounds() + - num_local_round : self.bst.num_boosted_rounds() + ] + if args.train_method == "bagging" + else self.bst + ) return bst diff --git a/examples/xgboost-comprehensive/run.sh b/examples/xgboost-comprehensive/run_bagging.sh similarity index 100% rename from examples/xgboost-comprehensive/run.sh rename to examples/xgboost-comprehensive/run_bagging.sh diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh new file mode 100755 index 000000000000..47e09fd8faef --- /dev/null +++ b/examples/xgboost-comprehensive/run_cyclic.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting server" +python3 server.py --train-method=cyclic --pool-size=5 --num-rounds=100 & +sleep 15 # Sleep for 15s to give the server enough time to start + +for i in `seq 0 4`; do + echo "Starting client $i" + python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/xgboost-comprehensive/server.py 
b/examples/xgboost-comprehensive/server.py index 3da7e8d9865c..1cf4ba79fa50 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -1,4 +1,5 @@ -from typing import Dict +import warnings +from typing import Dict, List, Optional from logging import INFO import xgboost as xgb @@ -7,13 +8,21 @@ from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset from flwr.server.strategy import FedXgbBagging +from flwr.server.strategy import FedXgbCyclic +from flwr.server.client_proxy import ClientProxy +from flwr.server.criterion import Criterion +from flwr.server.client_manager import SimpleClientManager from utils import server_args_parser, BST_PARAMS from dataset import resplit, transform_dataset_to_dmatrix +warnings.filterwarnings("ignore", category=UserWarning) + + # Parse arguments for experimental settings args = server_args_parser() +train_method = args.train_method pool_size = args.pool_size num_rounds = args.num_rounds num_clients_per_round = args.num_clients_per_round @@ -80,23 +89,72 @@ def evaluate_fn( return evaluate_fn +class CyclicClientManager(SimpleClientManager): + """Provides a cyclic client selection rule.""" + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + + # Block until at least num_clients are connected. 
+ if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + # Return all available clients + return [self.clients[cid] for cid in available_cids] + + # Define strategy -strategy = FedXgbBagging( - evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None, - fraction_fit=(float(num_clients_per_round) / pool_size), - min_fit_clients=num_clients_per_round, - min_available_clients=pool_size, - min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0, - fraction_evaluate=1.0 if not centralised_eval else 0.0, - on_evaluate_config_fn=eval_config, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation - if not centralised_eval - else None, -) +if train_method == "bagging": + # Bagging training + strategy = FedXgbBagging( + evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None, + fraction_fit=(float(num_clients_per_round) / pool_size), + min_fit_clients=num_clients_per_round, + min_available_clients=pool_size, + min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0, + fraction_evaluate=1.0 if not centralised_eval else 0.0, + on_evaluate_config_fn=eval_config, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation + if not centralised_eval + else None, + ) +else: + # Cyclic training + strategy = FedXgbCyclic( + fraction_fit=1.0, + min_available_clients=pool_size, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=eval_config, + ) # Start Flower server 
fl.server.start_server( server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=num_rounds), strategy=strategy, + client_manager=CyclicClientManager() if train_method == "cyclic" else None, ) diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 000def370752..8acdbbb88a7e 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -17,6 +17,13 @@ def client_args_parser(): """Parse arguments to define experimental settings on client side.""" parser = argparse.ArgumentParser() + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) parser.add_argument( "--num-partitions", default=10, type=int, help="Number of partitions." ) @@ -56,6 +63,13 @@ def server_args_parser(): """Parse arguments to define experimental settings on server side.""" parser = argparse.ArgumentParser() + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) parser.add_argument( "--pool-size", default=2, type=int, help="Number of total clients." 
) diff --git a/pyproject.toml b/pyproject.toml index 2349d554a409..1ccdc72666f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,8 +61,8 @@ flower-client = "flwr.client:run_client" python = "^3.8" # Mandatory dependencies numpy = "^1.21.0" -grpcio = "^1.48.2,!=1.52.0" -protobuf = "^3.19.0" +grpcio = "^1.60.0" +protobuf = "^4.25.2" cryptography = "^41.0.2" pycryptodome = "^3.18.0" iterators = "^0.0.2" @@ -81,21 +81,21 @@ rest = ["requests", "starlette", "uvicorn"] [tool.poetry.group.dev.dependencies] types-dataclasses = "==0.6.6" types-protobuf = "==3.19.18" -types-requests = "==2.31.0.2" -types-setuptools = "==68.2.0.0" +types-requests = "==2.31.0.10" +types-setuptools = "==69.0.0.20240115" clang-format = "==17.0.4" -isort = "==5.12.0" +isort = "==5.13.2" black = { version = "==23.10.1", extras = ["jupyter"] } docformatter = "==1.7.5" -mypy = "==1.6.1" -pylint = "==2.13.9" +mypy = "==1.8.0" +pylint = "==3.0.3" flake8 = "==5.0.4" -pytest = "==7.4.3" +pytest = "==7.4.4" pytest-cov = "==4.1.0" pytest-watch = "==4.2.0" -grpcio-tools = "==1.48.2" +grpcio-tools = "==1.60.0" mypy-protobuf = "==3.2.0" -jupyterlab = "==4.0.8" +jupyterlab = "==4.0.9" rope = "==1.11.0" semver = "==3.0.2" sphinx = "==6.2.1" @@ -109,7 +109,7 @@ furo = "==2023.9.10" sphinx-reredirects = "==0.1.3" nbsphinx = "==0.9.3" nbstripout = "==0.6.1" -ruff = "==0.1.4" +ruff = "==0.1.9" sphinx-argparse = "==0.4.0" pipreqs = "==0.4.13" mdformat-gfm = "==0.3.5" @@ -120,7 +120,8 @@ twine = "==4.0.2" pyroma = "==4.2" check-wheel-contents = "==0.4.0" GitPython = "==3.1.32" -licensecheck = "==2023.5.1" +PyGithub = "==2.1.1" +licensecheck = "==2024" [tool.isort] line_length = 88 @@ -136,7 +137,7 @@ line-length = 88 target-version = ["py38", "py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" [tool.pytest.ini_options] minversion = "6.2" @@ 
-183,7 +184,7 @@ target-version = "py38" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] -ignore = ["B024", "B027"] +ignore = ["B024", "B027", "D205", "D209"] exclude = [ ".bzr", ".direnv", diff --git a/src/docker/base/Dockerfile b/src/docker/base/Dockerfile new file mode 100644 index 000000000000..9cd410ba3fb5 --- /dev/null +++ b/src/docker/base/Dockerfile @@ -0,0 +1,45 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. + +ARG UBUNTU_VERSION=22.04 +FROM ubuntu:$UBUNTU_VERSION as base + +ENV DEBIAN_FRONTEND noninteractive +# Send stdout and stderr stream directly to the terminal. Ensures that no +# output is retained in a buffer if the application crashes. +ENV PYTHONUNBUFFERED 1 +# Typically, bytecode is created on the first invocation to speed up following invocation. +# However, in Docker we only make a single invocation (when we start the container). +# Therefore, we can disable bytecode writing. +ENV PYTHONDONTWRITEBYTECODE 1 +# Ensure that python encoding is always UTF-8. 
+ENV PYTHONIOENCODING UTF-8 +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 + +# Install system dependencies +RUN apt-get update \ + && apt-get -y --no-install-recommends install \ + clang-format git unzip ca-certificates openssh-client liblzma-dev \ + build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev wget\ + libsqlite3-dev curl llvm libncursesw5-dev xz-utils tk-dev libxml2-dev \ + libxmlsec1-dev libffi-dev liblzma-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install PyEnv and Python +ARG PYTHON_VERSION +ENV PYENV_ROOT /root/.pyenv +ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH +# https://github.com/hadolint/hadolint/wiki/DL4006 +SHELL ["/bin/bash", "-o", "pipefail", "-c"] +RUN curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash +RUN pyenv install ${PYTHON_VERSION} \ + && pyenv global ${PYTHON_VERSION} \ + && pyenv rehash + +# Install specific version of pip +ARG PIP_VERSION +RUN python -m pip install --no-cache-dir pip==$PIP_VERSION + +# Install specific version of setuptools +ARG SETUPTOOLS_VERSION +RUN python -m pip install --no-cache-dir setuptools==$SETUPTOOLS_VERSION diff --git a/src/docker/client/Dockerfile b/src/docker/client/Dockerfile new file mode 100644 index 000000000000..0755a7989281 --- /dev/null +++ b/src/docker/client/Dockerfile @@ -0,0 +1,8 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. + +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE_TAG +FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG + +ARG FLWR_VERSION +RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION} diff --git a/src/docker/server/Dockerfile b/src/docker/server/Dockerfile index 37f1ea91f33c..c42246b16104 100644 --- a/src/docker/server/Dockerfile +++ b/src/docker/server/Dockerfile @@ -1,51 +1,8 @@ # Copyright 2023 Flower Labs GmbH. All Rights Reserved. -ARG UBUNTU_VERSION=22.04 -FROM ubuntu:$UBUNTU_VERSION as base - -ENV DEBIAN_FRONTEND noninteractive -# Send stdout and stderr stream directly to the terminal. 
Ensures that no -# output is retained in a buffer if the application crashes. -ENV PYTHONUNBUFFERED 1 -# Typically, bytecode is created on the first invocation to speed up following invocation. -# However, in Docker we only make a single invocation (when we start the container). -# Therefore, we can disable bytecode writing. -ENV PYTHONDONTWRITEBYTECODE 1 -# Ensure that python encoding is always UTF-8. -ENV PYTHONIOENCODING UTF-8 -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 - -# Install system dependencies -RUN apt-get update \ - && apt-get -y --no-install-recommends install \ - clang-format git unzip ca-certificates openssh-client liblzma-dev \ - build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev wget\ - libsqlite3-dev curl llvm libncursesw5-dev xz-utils tk-dev libxml2-dev \ - libxmlsec1-dev libffi-dev liblzma-dev \ - && rm -rf /var/lib/apt/lists/* - -# Install PyEnv and Python -ARG PYTHON_VERSION -ENV PYENV_ROOT /root/.pyenv -ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH -# https://github.com/hadolint/hadolint/wiki/DL4006 -SHELL ["/bin/bash", "-o", "pipefail", "-c"] -RUN curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash -RUN pyenv install ${PYTHON_VERSION} \ - && pyenv global ${PYTHON_VERSION} \ - && pyenv rehash - -# Install specific version of pip -ARG PIP_VERSION -RUN python -m pip install --no-cache-dir pip==$PIP_VERSION - -# Install specific version of setuptools -ARG SETUPTOOLS_VERSION -RUN python -m pip install --no-cache-dir setuptools==$SETUPTOOLS_VERSION - -# Server image -FROM base as server +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE_TAG=py3.11-ubuntu22.04 +FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG as server WORKDIR /app ARG FLWR_VERSION diff --git a/src/kotlin/flwr/src/main/AndroidManifest.xml b/src/kotlin/flwr/src/main/AndroidManifest.xml index 8bdb7e14b389..3cb3262db448 100644 --- a/src/kotlin/flwr/src/main/AndroidManifest.xml +++ b/src/kotlin/flwr/src/main/AndroidManifest.xml @@ -1,4 +1,5 @@ - + + diff 
--git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index eb948217a4de..bc0062c4a51f 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -21,8 +21,8 @@ import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; service Driver { - // Request workload_id - rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {} + // Request run_id + rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} // Return a set of nodes rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {} @@ -34,12 +34,12 @@ service Driver { rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {} } -// CreateWorkload -message CreateWorkloadRequest {} -message CreateWorkloadResponse { sint64 workload_id = 1; } +// CreateRun +message CreateRunRequest {} +message CreateRunResponse { sint64 run_id = 1; } // GetNodes messages -message GetNodesRequest { sint64 workload_id = 1; } +message GetNodesRequest { sint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto new file mode 100644 index 000000000000..8e2e5d60b6db --- /dev/null +++ b/src/proto/flwr/proto/recordset.proto @@ -0,0 +1,70 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message DoubleList { repeated double vals = 1; } +message Sint64List { repeated sint64 vals = 1; } +message BoolList { repeated bool vals = 1; } +message StringList { repeated string vals = 1; } +message BytesList { repeated bytes vals = 1; } + +message Array { + string dtype = 1; + repeated int32 shape = 2; + string stype = 3; + bytes data = 4; +} + +message MetricsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + } +} + +message ConfigsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + bool bool = 3; + string string = 4; + bytes bytes = 5; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + BoolList bool_list = 23; + StringList string_list = 24; + BytesList bytes_list = 25; + } +} + +message ParametersRecord { + repeated string data_keys = 1; + repeated Array data_values = 2; +} + +message MetricsRecord { map data = 1; } + +message ConfigsRecord { map data = 1; } diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 2205ef2815c8..20dd5a3aa6c8 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/node.proto"; +import "flwr/proto/recordset.proto"; import "flwr/proto/transport.proto"; message Task { @@ -27,7 +28,8 @@ message Task { string delivered_at = 4; string ttl = 5; repeated string ancestry = 6; - SecureAggregation sa = 7; + string task_type = 7; + SecureAggregation sa = 8; ServerMessage legacy_server_message = 101 [ deprecated = true ]; ClientMessage legacy_client_message = 102 [ deprecated = true ]; @@ -36,24 +38,18 @@ message Task { message TaskIns { string task_id = 1; string group_id = 2; - sint64 
workload_id = 3; + sint64 run_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - sint64 workload_id = 3; + sint64 run_id = 3; Task task = 4; } message Value { - message DoubleList { repeated double vals = 1; } - message Sint64List { repeated sint64 vals = 1; } - message BoolList { repeated bool vals = 1; } - message StringList { repeated string vals = 1; } - message BytesList { repeated bytes vals = 1; } - oneof value { // Single element double double = 1; diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7ce7d51d3d4b..ae5beeae07d6 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -35,18 +35,20 @@ TRANSPORT_TYPES, ) from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from .flower import load_callable +from .flower import load_flower_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message +from .node_state import NodeState from .numpy_client import NumPyClient -from .workload_state import WorkloadState def run_client() -> None: """Run Flower client.""" + event(EventType.RUN_CLIENT_ENTER) + log(INFO, "Long-running Flower client starting") args = _parse_args_client().parse_args() @@ -72,24 +74,25 @@ def run_client() -> None: print(args.root_certificates) print(args.server) - print(args.callable_dir) + print(args.dir) print(args.callable) - callable_dir = args.callable_dir + callable_dir = args.dir if callable_dir is not None: sys.path.insert(0, callable_dir) def _load() -> Flower: - flower: Flower = load_callable(args.callable) + flower: Flower = load_flower_callable(args.callable) return flower - return start_client( + _start_client_internal( server_address=args.server, - load_callable_fn=_load, + 
load_flower_callable_fn=_load, transport="grpc-rere", # Only root_certificates=root_certificates, insecure=args.insecure, ) + event(EventType.RUN_CLIENT_LEAVE) def _parse_args_client() -> argparse.ArgumentParser: @@ -98,6 +101,10 @@ def _parse_args_client() -> argparse.ArgumentParser: description="Start a long-running Flower client", ) + parser.add_argument( + "callable", + help="For example: `client:flower` or `project.package.module:wrapper.flower`", + ) parser.add_argument( "--insecure", action="store_true", @@ -117,13 +124,10 @@ def _parse_args_client() -> argparse.ArgumentParser: help="Server address", ) parser.add_argument( - "--callable", - help="For example: `client:flower` or `project.package.module:wrapper.flower`", - ) - parser.add_argument( - "--callable-dir", + "--dir", default="", - help="Add specified directory to the PYTHONPATH and load callable from there." + help="Add specified directory to the PYTHONPATH and load Flower " + "callable from there." " Default: current working directory.", ) @@ -134,10 +138,12 @@ def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: - raise Exception("Both `client_fn` and `client` are `None`, but one is required") + raise ValueError( + "Both `client_fn` and `client` are `None`, but one is required" + ) if client_fn is not None and client is not None: - raise Exception( + raise ValueError( "Both `client_fn` and `client` are provided, but only one is allowed" ) @@ -146,10 +152,10 @@ def _check_actionable_client( # pylint: disable=too-many-branches # pylint: disable=too-many-locals # pylint: disable=too-many-statements +# pylint: disable=too-many-arguments def start_client( *, server_address: str, - load_callable_fn: Optional[Callable[[], Flower]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -165,8 +171,6 @@ def start_client( The IPv4 or IPv6 
address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - load_callable_fn : Optional[Callable[[], Flower]] (default: None) - ... client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -223,11 +227,73 @@ class `flwr.client.Client` (default: None) >>> ) """ event(EventType.START_CLIENT_ENTER) + _start_client_internal( + server_address=server_address, + load_flower_callable_fn=None, + client_fn=client_fn, + client=client, + grpc_max_message_length=grpc_max_message_length, + root_certificates=root_certificates, + insecure=insecure, + transport=transport, + ) + event(EventType.START_CLIENT_LEAVE) + +# pylint: disable=import-outside-toplevel +# pylint: disable=too-many-branches +# pylint: disable=too-many-locals +# pylint: disable=too-many-statements +def _start_client_internal( + *, + server_address: str, + load_flower_callable_fn: Optional[Callable[[], Flower]] = None, + client_fn: Optional[ClientFn] = None, + client: Optional[Client] = None, + grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, + root_certificates: Optional[Union[bytes, str]] = None, + insecure: Optional[bool] = None, + transport: Optional[str] = None, +) -> None: + """Start a Flower client node which connects to a Flower server. + + Parameters + ---------- + server_address : str + The IPv4 or IPv6 address of the server. If the Flower + server runs on the same machine on port 8080, then `server_address` + would be `"[::]:8080"`. + load_flower_callable_fn : Optional[Callable[[], Flower]] (default: None) + A function that can be used to load a `Flower` callable instance. + client_fn : Optional[ClientFn] + A callable that instantiates a Client. 
(default: None) + client : Optional[flwr.client.Client] + An implementation of the abstract base + class `flwr.client.Client` (default: None) + grpc_max_message_length : int (default: 536_870_912, this equals 512MB) + The maximum length of gRPC messages that can be exchanged with the + Flower server. The default should be sufficient for most models. + Users who train very large models might need to increase this + value. Note that the Flower server needs to be started with the + same value (see `flwr.server.start_server`), otherwise it will not + know about the increased limit and block larger messages. + root_certificates : Optional[Union[bytes, str]] (default: None) + The PEM-encoded root certificates as a byte string or a path string. + If provided, a secure connection using the certificates will be + established to an SSL-enabled Flower server. + insecure : bool (default: True) + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + transport : Optional[str] (default: None) + Configure the transport layer. 
Allowed values: + - 'grpc-bidi': gRPC, bidirectional streaming + - 'grpc-rere': gRPC, request-response (experimental) + - 'rest': HTTP (experimental) + """ if insecure is None: insecure = root_certificates is None - if load_callable_fn is None: + if load_flower_callable_fn is None: _check_actionable_client(client, client_fn) if client_fn is None: @@ -236,7 +302,7 @@ def single_client_factory( cid: str, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy - raise Exception( + raise ValueError( "Both `client_fn` and `client` are `None`, but one is required" ) return client # Always return the same instance @@ -246,16 +312,18 @@ def single_client_factory( def _load_app() -> Flower: return Flower(client_fn=client_fn) - load_callable_fn = _load_app + load_flower_callable_fn = _load_app else: - warn_experimental_feature("`load_callable_fn`") + warn_experimental_feature("`load_flower_callable_fn`") - # At this point, only `load_callable_fn` should be used + # At this point, only `load_flower_callable_fn` should be used # Both `client` and `client_fn` must not be used directly # Initialize connection context manager connection, address = _init_connection(transport, server_address) + node_state = NodeState() + while True: sleep_duration: int = 0 with connection( @@ -283,16 +351,25 @@ def _load_app() -> Flower: send(task_res) break + # Register state + node_state.register_runstate(run_id=task_ins.run_id) + # Load app - app: Flower = load_callable_fn() + app: Flower = load_flower_callable_fn() # Handle task message fwd_msg: Fwd = Fwd( task_ins=task_ins, - state=WorkloadState(state={}), + state=node_state.retrieve_runstate(run_id=task_ins.run_id), ) bwd_msg: Bwd = app(fwd=fwd_msg) + # Update node state + node_state.update_runstate( + run_id=bwd_msg.task_res.run_id, + run_state=bwd_msg.state, + ) + # Send send(bwd_msg.task_res) @@ -311,8 +388,6 @@ def _load_app() -> Flower: ) time.sleep(sleep_duration) - 
event(EventType.START_CLIENT_LEAVE) - def start_numpy_client( *, diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 7ef6410debad..56d6308a0fe2 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -41,19 +41,19 @@ class PlainClient(Client): def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit(self, ins: FitIns) -> FitRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate(self, ins: EvaluateIns) -> EvaluateRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() class NeedsWrappingClient(NumPyClient): @@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient): def get_properties(self, config: Config) -> Dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, config: Config) -> NDArrays: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit( self, parameters: NDArrays, config: Config ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config ) -> Tuple[float, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def test_to_client_with_client() -> None: diff --git 
a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 280e0a8ca989..54b53296fd2f 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,7 +19,7 @@ from abc import ABC -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import ( Code, EvaluateIns, @@ -38,7 +38,7 @@ class Client(ABC): """Abstract base class for Flower clients.""" - state: WorkloadState + state: RunState def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,12 +141,12 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" + def get_state(self) -> RunState: + """Get the run state from this client.""" return self.state - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" + def set_state(self, state: RunState) -> None: + """Apply a run state to this client.""" self.state = state def to_client(self) -> Client: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index 41b4d676df43..c39b89b31da3 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -117,16 +117,16 @@ def fit( update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)] if "dpfedavg_clip_norm" not in config: - raise Exception("Clipping threshold not supplied by the server.") + raise KeyError("Clipping threshold not supplied by the server.") if not isinstance(config["dpfedavg_clip_norm"], float): - raise Exception("Clipping threshold should be a floating point value.") + raise TypeError("Clipping threshold should be a floating point value.") # Clipping update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"]) if "dpfedavg_noise_stddev" in config: if not isinstance(config["dpfedavg_noise_stddev"], 
float): - raise Exception( + raise TypeError( "Scale of noise to be added should be a floating point value." ) # Noising @@ -138,7 +138,7 @@ def fit( # Calculating value of norm indicator bit, required for adaptive clipping if "dpfedavg_adaptive_clip_enabled" in config: if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool): - raise Exception( + raise TypeError( "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag." ) metrics["dpfedavg_norm_bit"] = not clipped diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 5b083ee11b9f..535f096e5866 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -16,10 +16,11 @@ import importlib -from typing import cast +from typing import List, Optional, cast from flwr.client.message_handler.message_handler import handle -from flwr.client.typing import Bwd, ClientFn, Fwd +from flwr.client.middleware.utils import make_ffn +from flwr.client.typing import Bwd, ClientFn, Fwd, Layer class Flower: @@ -51,28 +52,30 @@ class Flower: def __init__( self, client_fn: ClientFn, # Only for backward compatibility + layers: Optional[List[Layer]] = None, ) -> None: - self.client_fn = client_fn + # Create wrapper function for `handle` + def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name + task_res, state_updated = handle( + client_fn=client_fn, + state=fwd.state, + task_ins=fwd.task_ins, + ) + return Bwd(task_res=task_res, state=state_updated) + + # Wrap middleware layers around the wrapped handle function + self._call = make_ffn(ffn, layers if layers is not None else []) def __call__(self, fwd: Fwd) -> Bwd: """.""" - # Execute the task - task_res, state_updated = handle( - client_fn=self.client_fn, - state=fwd.state, - task_ins=fwd.task_ins, - ) - return Bwd( - task_res=task_res, - state=state_updated, - ) + return self._call(fwd) class LoadCallableError(Exception): """.""" -def load_callable(module_attribute_str: str) -> Flower: +def 
load_flower_callable(module_attribute_str: str) -> Flower: """Load the `Flower` object specified in a module attribute string. The module/attribute string should have the form :. Valid diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 335d28e72828..5f11912c587c 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -25,10 +25,13 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage -from flwr.proto.transport_pb2_grpc import FlowerServiceStub +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) +from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611 # The following flags can be uncommented for debugging. 
Other possible values: # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md @@ -119,7 +122,7 @@ def receive() -> TaskIns: return TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index e5944230e5af..bcfa76bb36c0 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,8 +23,11 @@ import grpc -from flwr.proto.task_pb2 import Task, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_manager import SimpleClientManager from flwr.server.fleet.grpc_bidi.grpc_server import start_grpc_server diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 30d407a52c53..cb1a7021dc9d 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -29,15 +29,15 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from flwr.proto.fleet_pb2_grpc import FleetStub -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 
KEY_NODE = "node" KEY_TASK_INS = "current_task_ins" diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 0f3070cfb01a..8cfe909c1738 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -28,12 +28,21 @@ get_server_message_from_task_ins, wrap_client_message_in_task_res, ) +from flwr.client.run_state import RunState from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState from flwr.common import serde -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage +from flwr.proto.task_pb2 import ( # pylint: disable=E0611 + SecureAggregation, + Task, + TaskIns, + TaskRes, +) +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Reason, + ServerMessage, +) class UnexpectedServerMessage(Exception): @@ -79,16 +88,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: def handle( - client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns -) -> Tuple[TaskRes, WorkloadState]: + client_fn: ClientFn, state: RunState, task_ins: TaskIns +) -> Tuple[TaskRes, RunState]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. + state : RunState + A dataclass storing the state for the run being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. 
@@ -112,7 +121,7 @@ def handle( task_res = TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( ancestry=[], sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), @@ -126,16 +135,16 @@ def handle( def handle_legacy_message( - client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage -) -> Tuple[ClientMessage, WorkloadState]: + client_fn: ClientFn, state: RunState, server_msg: ServerMessage +) -> Tuple[ClientMessage, RunState]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. + state : RunState + A dataclass storing the state for the run being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index d7f410d81fc0..194f75fe30ca 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -18,8 +18,8 @@ import uuid from flwr.client import Client +from flwr.client.run_state import RunState from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState from flwr.common import ( EvaluateIns, EvaluateRes, @@ -33,9 +33,14 @@ serde, typing, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, Code, ServerMessage, Status +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Code, + ServerMessage, + Status, +) from .message_handler import handle, handle_control_message @@ -121,7 +126,7 @@ 
def test_client_without_get_properties() -> None: task_ins: TaskIns = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -136,7 +141,7 @@ def test_client_without_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), + state=RunState(state={}), task_ins=task_ins, ) @@ -152,7 +157,7 @@ def test_client_without_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, ) ) # pylint: disable=no-member @@ -189,7 +194,7 @@ def test_client_with_get_properties() -> None: task_ins = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -204,7 +209,7 @@ def test_client_with_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), + state=RunState(state={}), task_ins=task_ins, ) @@ -220,7 +225,7 @@ def test_client_with_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, ) ) # pylint: disable=no-member diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index fc24539998c0..667cb9c98d46 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -17,10 +17,13 @@ from typing import Optional -from flwr.proto.fleet_pb2 import PullTaskInsResponse -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # 
pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: @@ -70,7 +73,7 @@ def validate_task_res(task_res: TaskRes) -> bool: Returns ------- is_valid: bool - True if the `task_id`, `group_id`, and `workload_id` fields in TaskRes + True if the `task_id`, `group_id`, and `run_id` fields in TaskRes and the `producer`, `consumer`, and `ancestry` fields in its sub-message Task are not initialized accidentally elsewhere, False otherwise. @@ -80,11 +83,10 @@ def validate_task_res(task_res: TaskRes) -> bool: initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} # Check if certain fields are already initialized - # pylint: disable-next=too-many-boolean-expressions - if ( + if ( # pylint: disable-next=too-many-boolean-expressions "task_id" in initialized_fields_in_task_res or "group_id" in initialized_fields_in_task_res - or "workload_id" in initialized_fields_in_task_res + or "run_id" in initialized_fields_in_task_res or "producer" in initialized_fields_in_task or "consumer" in initialized_fields_in_task or "ancestry" in initialized_fields_in_task @@ -129,7 +131,7 @@ def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: return TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task(ancestry=[], legacy_client_message=client_message), ) @@ -139,7 +141,7 @@ def configure_task_res( ) -> TaskRes: """Set the metadata of a TaskRes. - Fill `group_id` and `workload_id` in TaskRes + Fill `group_id` and `run_id` in TaskRes and `producer`, `consumer`, and `ancestry` in Task in TaskRes. `producer` in Task in TaskRes will remain unchanged/unset. 
@@ -152,7 +154,7 @@ def configure_task_res( task_res = TaskRes( task_id="", # This will be generated by the server group_id=ref_task_ins.group_id, - workload_id=ref_task_ins.workload_id, + run_id=ref_task_ins.run_id, task=task_res.task, ) # pylint: disable-next=no-member diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 21f3a2ead98a..c1111d0935c0 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -22,9 +22,17 @@ validate_task_res, wrap_client_message_in_task_res, ) -from flwr.proto.fleet_pb2 import PullTaskInsResponse -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 +from flwr.proto.task_pb2 import ( # pylint: disable=E0611 + SecureAggregation, + Task, + TaskIns, + TaskRes, +) +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) def test_validate_task_ins_no_task() -> None: @@ -92,7 +100,7 @@ def test_validate_task_res() -> None: assert not validate_task_res(task_res) task_res.Clear() - task_res.workload_id = 61016 + task_res.run_id = 61016 assert not validate_task_res(task_res) task_res.Clear() diff --git a/dev/publish.sh b/src/py/flwr/client/middleware/__init__.py old mode 100755 new mode 100644 similarity index 78% rename from dev/publish.sh rename to src/py/flwr/client/middleware/__init__.py index fb4df1694530..58b31296fbbe --- a/dev/publish.sh +++ b/src/py/flwr/client/middleware/__init__.py @@ -1,6 +1,4 @@ -#!/bin/bash - -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Middleware layers.""" + -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ +from .utils import make_ffn -python -m poetry publish +__all__ = [ + "make_ffn", +] diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py new file mode 100644 index 000000000000..d93132403c1e --- /dev/null +++ b/src/py/flwr/client/middleware/utils.py @@ -0,0 +1,35 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utility functions for middleware layers.""" + + +from typing import List + +from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer + + +def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: + """.""" + + def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: + def new_ffn(fwd: Fwd) -> Bwd: + return _layer(fwd, _ffn) + + return new_ffn + + for layer in reversed(layers): + ffn = wrap_ffn(ffn, layer) + + return ffn diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py new file mode 100644 index 000000000000..006fe6db4799 --- /dev/null +++ b/src/py/flwr/client/middleware/utils_test.py @@ -0,0 +1,99 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the utility functions.""" + + +import unittest +from typing import List + +from flwr.client.run_state import RunState +from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 + +from .utils import make_ffn + + +def make_mock_middleware(name: str, footprint: List[str]) -> Layer: + """Make a mock middleware layer.""" + + def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: + footprint.append(name) + fwd.task_ins.task_id += f"{name}" + bwd = app(fwd) + footprint.append(name) + bwd.task_res.task_id += f"{name}" + return bwd + + return middleware + + +def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: + """Make a mock app.""" + + def app(fwd: Fwd) -> Bwd: + footprint.append(name) + fwd.task_ins.task_id += f"{name}" + return Bwd(task_res=TaskRes(task_id=name), state=RunState({})) + + return app + + +class TestMakeApp(unittest.TestCase): + """Tests for the `make_app` function.""" + + def test_multiple_middlewares(self) -> None: + """Test if multiple middlewares are called in the correct order.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + mock_middleware_names = [f"middleware{i}" for i in range(1, 15)] + mock_middleware_layers = [ + make_mock_middleware(name, footprint) for name in mock_middleware_names + ] + task_ins = TaskIns() + + # Execute + wrapped_app = make_ffn(mock_app, mock_middleware_layers) + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + + # Assert + trace = mock_middleware_names + ["app"] + self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) + # pylint: disable-next=no-member + self.assertEqual(task_ins.task_id, "".join(trace)) + self.assertEqual(task_res.task_id, "".join(reversed(trace))) + + def test_filter(self) -> None: + """Test if a middleware can filter incoming 
TaskIns.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + task_ins = TaskIns() + + def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: + footprint.append("filter") + fwd.task_ins.task_id += "filter" + # Skip calling app + return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) + + # Execute + wrapped_app = make_ffn(mock_app, [filter_layer]) + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + + # Assert + self.assertEqual(footprint, ["filter"]) + # pylint: disable-next=no-member + self.assertEqual(task_ins.task_id, "filter") + self.assertEqual(task_res.task_id, "filter") diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py new file mode 100644 index 000000000000..0a29be511806 --- /dev/null +++ b/src/py/flwr/client/node_state.py @@ -0,0 +1,48 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Node state.""" + + +from typing import Any, Dict + +from flwr.client.run_state import RunState + + +class NodeState: + """State of a node where client nodes execute runs.""" + + def __init__(self) -> None: + self._meta: Dict[str, Any] = {} # holds metadata about the node + self.run_states: Dict[int, RunState] = {} + + def register_runstate(self, run_id: int) -> None: + """Register new run state for this node.""" + if run_id not in self.run_states: + self.run_states[run_id] = RunState({}) + + def retrieve_runstate(self, run_id: int) -> RunState: + """Get run state given a run_id.""" + if run_id in self.run_states: + return self.run_states[run_id] + + raise RuntimeError( + f"RunState for run_id={run_id} doesn't exist." + " A run must be registered before it can be retrieved or updated " + " by a client." + ) + + def update_runstate(self, run_id: int, run_state: RunState) -> None: + """Update run state.""" + self.run_states[run_id] = run_state diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py new file mode 100644 index 000000000000..7bc0d77d16cf --- /dev/null +++ b/src/py/flwr/client/node_state_tests.py @@ -0,0 +1,59 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Node state tests.""" + + +from flwr.client.node_state import NodeState +from flwr.client.run_state import RunState +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + + +def _run_dummy_task(state: RunState) -> RunState: + if "counter" in state.state: + state.state["counter"] += "1" + else: + state.state["counter"] = "1" + + return state + + +def test_multirun_in_node_state() -> None: + """Test basic NodeState logic.""" + # Tasks to perform + tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]] + # the "tasks" is to count how many times each run is executed + expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} + + # NodeState + node_state = NodeState() + + for task in tasks: + run_id = task.run_id + + # Register + node_state.register_runstate(run_id=run_id) + + # Get run state + state = node_state.retrieve_runstate(run_id=run_id) + + # Run "task" + updated_state = _run_dummy_task(state) + + # Update run state + node_state.update_runstate(run_id=run_id, run_state=updated_state) + + # Verify values + for run_id, state in node_state.run_states.items(): + assert state.state["counter"] == expected_values[run_id] diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 8b0893ea30aa..d67fb90512d4 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,7 +19,7 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import ( Config, NDArrays, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: WorkloadState + state: RunState def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. 
@@ -174,12 +174,12 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" + def get_state(self) -> RunState: + """Get the run state from this client.""" return self.state - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" + def set_state(self, state: RunState) -> None: + """Apply a run state to this client.""" self.state = state def to_client(self) -> Client: @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) # Return FitRes parameters_prime, num_examples, metrics = results @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) # Return EvaluateRes loss, num_examples, metrics = results @@ -278,12 +278,12 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> WorkloadState: +def _get_state(self: Client) -> RunState: """Return state of underlying NumPyClient.""" return self.numpy_client.get_state() # type: ignore -def _set_state(self: Client, state: WorkloadState) -> None: +def _set_state(self: Client, state: RunState) -> None: """Apply state to underlying NumPyClient.""" self.numpy_client.set_state(state) # type: ignore diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d22b246dbd61..bb55f130f1a8 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -29,7 +29,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.constant import 
MISSING_EXTRA_REST from flwr.common.logger import log -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -38,8 +38,8 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 try: import requests @@ -143,6 +143,7 @@ def create_node() -> None: }, data=create_node_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -185,6 +186,7 @@ def delete_node() -> None: }, data=delete_node_req_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]: }, data=pull_task_ins_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None: }, data=push_task_res_request_bytes, verify=verify, + timeout=None, ) state[KEY_TASK_INS] = None diff --git a/src/py/flwr/client/workload_state.py b/src/py/flwr/client/run_state.py similarity index 88% rename from src/py/flwr/client/workload_state.py rename to src/py/flwr/client/run_state.py index 42ae2a925f47..c2755eb995eb 100644 --- a/src/py/flwr/client/workload_state.py +++ b/src/py/flwr/client/run_state.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class RunState:
    """Hold the state of one run executed by a client node.

    The state is a plain dictionary mapping string keys to string values.
    """

    state: Dict[str, str]
for src, ciphertext in zip(srcs, ciphertexts): @@ -409,7 +409,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - ValueError( + raise ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 2c1f7506592c..5291afb83d98 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -17,8 +17,8 @@ from dataclasses import dataclass from typing import Callable -from flwr.client.workload_state import WorkloadState -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.client.run_state import RunState +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -28,7 +28,7 @@ class Fwd: """.""" task_ins: TaskIns - state: WorkloadState + state: RunState @dataclass @@ -36,8 +36,9 @@ class Bwd: """.""" task_res: TaskRes - state: WorkloadState + state: RunState FlowerCallable = Callable[[Fwd], Bwd] ClientFn = Callable[[str], Client] +Layer = Callable[[Fwd, FlowerCallable], Bwd] diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py new file mode 100644 index 000000000000..b0480841e06c --- /dev/null +++ b/src/py/flwr/common/configsrecord.py @@ -0,0 +1,116 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
@dataclass
class ConfigsRecord:
    """Configs record.

    Stores configuration entries in `data`, a dictionary with `str` keys. Values are
    either scalars of the types listed in `ConfigsScalar` or lists of such scalars.
    """

    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (as defined in `ConfigsScalar`)
            and lists of such types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether the entries of `configs_dict` are kept in the input dictionary
            after being added to the record. When True, the input is left untouched
            and the mapping is duplicated in memory. When False, each entry is
            removed from `configs_dict` as it is added to the record.
        """
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict, keep_input=keep_input)

    def set_configs(
        self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
    ) -> None:
        """Add configs to the record.

        Entries are merged into the record regardless of `keep_input`; a key that
        is already present is overwritten. Validation happens before anything is
        stored, so a rejected input leaves the record and the input unchanged.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (as defined in `ConfigsScalar`)
            and lists of such types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether the entries of `configs_dict` are kept in the input dictionary
            after being added to the record. When True, the mapping is duplicated
            in memory. If memory is a concern, set it to False.

        Raises
        ------
        TypeError
            If a key is not a `str`, if a value is not of a supported type, or if
            the elements of a list are not all of the same type.
        """
        if any(not isinstance(k, str) for k in configs_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        def is_valid(value: ConfigsScalar) -> None:
            """Check if value is of expected type."""
            if not isinstance(value, get_args(ConfigsScalar)):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        for value in configs_dict.values():
            if not isinstance(value, list):
                is_valid(value)
                continue
            # Empty lists are valid.
            if len(value) == 0:
                continue
            # Checking every element is linear in the list length; very large
            # lists are better passed as an array in a ParametersRecord.
            is_valid(value[0])
            # All elements must share one type (needed for protobuf). `bool` is a
            # subclass of `int`, so a plain isinstance() test would let `[1, True]`
            # through as a list of int; compare "bool-ness" explicitly as well.
            value_type = type(value[0])
            first_is_bool = isinstance(value[0], bool)
            if not all(
                isinstance(v, value_type) and isinstance(v, bool) is first_is_bool
                for v in value
            ):
                raise TypeError(
                    "All values in a list must be of the same valid type. "
                    f"One of {ConfigsScalar}."
                )

        # Add configs to record
        if keep_input:
            # Shallow copy of the entries; the input dictionary is left as is.
            # Merging (instead of rebinding `self.data`) keeps earlier entries,
            # matching what the `keep_input=False` branch does.
            self.data.update(configs_dict)
        else:
            # Move entries one by one so the mapping is never held twice.
            for key in list(configs_dict.keys()):
                self.data[key] = configs_dict.pop(key)

    def __getitem__(self, key: str) -> ConfigsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
@dataclass
class FlowerContext:
    """Application state as seen by the entity that is using it.

    Parameters
    ----------
    in_message : RecordSet
        Records received from another entity (e.g. sent by the server-side logic
        to a client, or vice-versa).
    out_message : RecordSet
        Records added by the current entity; this `RecordSet` is the one that will
        be sent out (e.g. back to the server-side for aggregation of parameters,
        or to the client to perform a certain task).
    local : RecordSet
        Records added by the current entity that stay local, i.e. the data never
        leaves the system it is running from. Usable as intermediate storage or
        scratchpad when executing middleware layers, or as memory accessed at
        different points during the lifecycle of this entity (e.g. across
        multiple rounds).
    metadata : Metadata
        Information about the task to be executed.
    """

    in_message: RecordSet
    out_message: RecordSet
    local: RecordSet
    metadata: Metadata
@dataclass
class MetricsRecord:
    """Metrics record.

    Stores metrics in `data`, a dictionary with `str` keys. Values are either
    scalars of the types listed in `MetricsScalar` or lists of such scalars;
    `bool` values are rejected.
    """

    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether the entries of `metrics_dict` are kept in the input dictionary
            after being added to the record. When True, the input is left untouched
            and the mapping is duplicated in memory. When False, each entry is
            removed from `metrics_dict` as it is added to the record.
        """
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict, keep_input=keep_input)

    def set_metrics(
        self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True
    ) -> None:
        """Add metrics to the record.

        Entries are merged into the record regardless of `keep_input`; a key that
        is already present is overwritten. Validation happens before anything is
        stored, so a rejected input leaves the record and the input unchanged.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether the entries of `metrics_dict` are kept in the input dictionary
            after being added to the record. When True, the mapping is duplicated
            in memory. If memory is a concern, set it to False.

        Raises
        ------
        TypeError
            If a key is not a `str`, if a value is not of a supported type, or if
            the elements of a list are not all of the same type.
        """
        if any(not isinstance(k, str) for k in metrics_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        def is_valid(value: MetricsScalar) -> None:
            """Check if value is of expected type."""
            # `bool` is a subclass of `int`, hence the explicit exclusion.
            if not isinstance(value, get_args(MetricsScalar)) or isinstance(
                value, bool
            ):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        for value in metrics_dict.values():
            if not isinstance(value, list):
                is_valid(value)
                continue
            # Empty lists are valid.
            if len(value) == 0:
                continue
            # Checking every element is linear in the list length; very large
            # lists are better passed as an array in a ParametersRecord.
            is_valid(value[0])
            # All elements must share one type (needed for protobuf). A plain
            # isinstance() test would accept `[1, True]` because `bool` subclasses
            # `int`, so `bool` elements are rejected here just as in `is_valid`.
            value_type = type(value[0])
            if not all(
                isinstance(v, value_type) and not isinstance(v, bool) for v in value
            ):
                raise TypeError(
                    "All values in a list must be of the same valid type. "
                    f"One of {MetricsScalar}."
                )

        # Add metrics to record
        if keep_input:
            # Shallow copy of the entries; the input dictionary is left as is.
            # Merging (instead of rebinding `self.data`) keeps earlier entries,
            # matching what the `keep_input=False` branch does.
            self.data.update(metrics_dict)
        else:
            # Move entries one by one so the mapping is never held twice.
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict.pop(key)

    def __getitem__(self, key: str) -> MetricsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
@dataclass
class ParametersRecord:
    """Parameters record.

    A dataclass storing named Arrays in order. This means that it holds entries as an
    OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
    PyTorch's state_dict, but holding serialised tensors instead.
    """

    data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])

    def __init__(
        self,
        array_dict: Optional[OrderedDict[str, Array]] = None,
        keep_input: bool = False,
    ) -> None:
        """Construct a ParametersRecord object.

        Parameters
        ----------
        array_dict : Optional[OrderedDict[str, Array]]
            A dictionary that stores serialized array-like or tensor-like objects.
        keep_input : bool (default: False)
            Whether the entries of `array_dict` are kept in the input dictionary
            after being added to the record. If False, the dictionary passed in
            will be empty afterwards, which is the desired behaviour when working
            with very large models/tensors/arrays. Set it to True if you plan to
            keep working with the dictionary; the mapping is then duplicated in
            memory.
        """
        self.data = OrderedDict()
        if array_dict:
            self.set_parameters(array_dict, keep_input=keep_input)

    def set_parameters(
        self, array_dict: OrderedDict[str, Array], keep_input: bool = False
    ) -> None:
        """Add parameters to record.

        Entries are merged into the record regardless of `keep_input`: new keys
        are appended in the order of `array_dict`, and a key that is already
        present is overwritten in place. Validation happens before anything is
        stored.

        Parameters
        ----------
        array_dict : OrderedDict[str, Array]
            A dictionary that stores serialized array-like or tensor-like objects.
        keep_input : bool (default: False)
            Whether the entries of `array_dict` are kept in the input dictionary
            after being added to the record. When False, each entry is removed
            from `array_dict` as it is added.

        Raises
        ------
        TypeError
            If a key is not a `str` or a value is not an `Array`.
        """
        if any(not isinstance(k, str) for k in array_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")
        if any(not isinstance(v, Array) for v in array_dict.values()):
            raise TypeError(f"Not all values are of valid type. Expected {Array}")

        if keep_input:
            # Shallow copy of the entries; the input dictionary is left as is.
            # Merging (instead of rebinding `self.data`) keeps earlier entries,
            # matching what the `keep_input=False` branch does.
            self.data.update(array_dict)
        else:
            # Move entries one by one so the mapping is never held twice.
            for key in list(array_dict.keys()):
                self.data[key] = array_dict.pop(key)

    def __getitem__(self, key: str) -> Array:
        """Retrieve an element stored in record."""
        return self.data[key]
@dataclass
class RecordSet:
    """Named collections of parameters, metrics and configs records.

    Each of the three attributes is a dictionary keyed by record name. The `get_*`
    and `del_*` methods index the dictionary directly, so an unknown name raises
    `KeyError`.
    """

    parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
    metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
    configs: Dict[str, ConfigsRecord] = field(default_factory=dict)

    # --- store -----------------------------------------------------------------

    def set_parameters(self, name: str, record: ParametersRecord) -> None:
        """Store a ParametersRecord under `name`, replacing any existing one."""
        self.parameters[name] = record

    def set_metrics(self, name: str, record: MetricsRecord) -> None:
        """Store a MetricsRecord under `name`, replacing any existing one."""
        self.metrics[name] = record

    def set_configs(self, name: str, record: ConfigsRecord) -> None:
        """Store a ConfigsRecord under `name`, replacing any existing one."""
        self.configs[name] = record

    # --- look up ---------------------------------------------------------------

    def get_parameters(self, name: str) -> ParametersRecord:
        """Return the ParametersRecord stored under `name`."""
        return self.parameters[name]

    def get_metrics(self, name: str) -> MetricsRecord:
        """Return the MetricsRecord stored under `name`."""
        return self.metrics[name]

    def get_configs(self, name: str) -> ConfigsRecord:
        """Return the ConfigsRecord stored under `name`."""
        return self.configs[name]

    # --- remove ----------------------------------------------------------------

    def del_parameters(self, name: str) -> None:
        """Remove the ParametersRecord stored under `name`."""
        del self.parameters[name]

    def del_metrics(self, name: str) -> None:
        """Remove the MetricsRecord stored under `name`."""
        del self.metrics[name]

    def del_configs(self, name: str) -> None:
        """Remove the ConfigsRecord stored under `name`."""
        del self.configs[name]
def get_ndarrays() -> NDArrays:
    """Return list of NumPy arrays."""
    # A 3x2 float matrix and a 2x7 matrix with ones on the diagonal offset by 3.
    arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]])
    arr2 = np.eye(2, 7, 3)

    return [arr1, arr2]


def ndarray_to_array(ndarray: NDArray) -> Array:
    """Represent NumPy ndarray as Array."""
    # The raw buffer goes into `data`; dtype and shape are recorded alongside it
    # so the ndarray can be rebuilt with `np.frombuffer(...).reshape(...)`.
    return Array(
        data=ndarray.tobytes(),
        dtype=str(ndarray.dtype),
        stype="numpy.ndarray.tobytes",
        shape=list(ndarray.shape),
    )


def test_ndarray_to_array() -> None:
    """Test creation of Array object from NumPy ndarray."""
    shape = (2, 7, 9)
    # `np.eye(*shape)` would be read as `np.eye(N=2, M=7, k=9)`, i.e. an all-zero
    # 2x7 matrix rather than an array of this shape. Build a 3-D array with
    # distinct entries instead so the round trip is checked on real data.
    arr = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)

    array = ndarray_to_array(arr)

    arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

    assert np.array_equal(arr, arr_)
def test_parameters_to_parametersrecord_and_back() -> None:
    """Test conversion between legacy Parameters and ParametersRecords."""
    original = get_ndarrays()

    # ndarrays -> Parameters -> ParametersRecord
    record = parameters_to_parametersrecord(
        parameters=ndarrays_to_parameters(original)
    )

    # ParametersRecord -> Parameters -> ndarrays
    restored = parameters_to_ndarrays(
        parameters=parametersrecord_to_parameters(record)
    )

    for expected, actual in zip(original, restored):
        assert np.array_equal(expected, actual)


def test_set_parameters_while_keeping_intputs() -> None:
    """Tests keep_input functionality in ParametersRecord."""
    arrays = OrderedDict(
        (str(idx), ndarray_to_array(arr)) for idx, arr in enumerate(get_ndarrays())
    )

    # With keep_input=True the input dictionary must not be emptied ...
    kept = ParametersRecord(keep_input=True)
    kept.set_parameters(arrays, keep_input=True)

    # ... so the very same dictionary can feed a second record.
    moved = ParametersRecord(arrays)
    assert kept.data == moved.data

    # The second record used the default (keep_input=False): input is drained.
    assert len(arrays) == 0
@pytest.mark.parametrize(
    "key_type, value_fn",
    [
        (str, lambda x: x),  # valid key, invalid value (ndarray)
        (str, lambda x: x.tolist()),  # valid key, invalid value (list)
        (int, ndarray_to_array),  # invalid key, valid value
        (int, lambda x: x),  # invalid key, invalid value (ndarray)
        (int, lambda x: x.tolist()),  # invalid key, invalid value (list)
    ],
)
def test_set_parameters_with_incorrect_types(
    key_type: Type[Union[int, str]],
    value_fn: Callable[[NDArray], Union[NDArray, List[float]]],
) -> None:
    """Test adding dictionary of unsupported types to ParametersRecord."""
    record = ParametersRecord()

    bad_input = {}
    for idx, arr in enumerate(get_ndarrays()):
        bad_input[key_type(idx)] = value_fn(arr)

    with pytest.raises(TypeError):
        record.set_parameters(bad_input)  # type: ignore


@pytest.mark.parametrize(
    "key_type, value_fn",
    [
        (str, lambda x: int(x.flatten()[0])),  # int scalar
        (str, lambda x: float(x.flatten()[0])),  # float scalar
        (str, lambda x: x.flatten().astype("int").tolist()),  # list of int
        (str, lambda x: x.flatten().astype("float").tolist()),  # list of float
        (str, lambda x: []),  # empty list
    ],
)
def test_set_metrics_to_metricsrecord_with_correct_types(
    key_type: Type[str],
    value_fn: Callable[[NDArray], MetricsRecordValues],
) -> None:
    """Test adding metrics of various types to a MetricsRecord."""
    record = MetricsRecord()

    # One entry per (label, array) pair; keys are the stringified labels.
    metrics = OrderedDict(
        (key_type(label), value_fn(arr))
        for label, arr in zip([1, 2.0], get_ndarrays())
    )

    record.set_metrics(metrics)

    # The record must hold exactly what was passed in.
    assert metrics == record.data
x.flatten().astype("str").tolist(), + ), # str: List[str] (supported: unsupported) + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_metrics_to_metricsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float, bool]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding metrics of various unsupported types to a MetricsRecord.""" + m_record = MetricsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + with pytest.raises(TypeError): + m_record.set_metrics(my_metrics) # type: ignore + + +@pytest.mark.parametrize( + "keep_input", + [ + (True), + (False), + ], +) +def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( + keep_input: bool, +) -> None: + """Test keep_input functionality for MetricsRecord.""" + m_record = MetricsRecord(keep_input=keep_input) + + # constructing a valid input + labels = [1, 2.0] + arrays = get_ndarrays() + my_metrics = OrderedDict( + {str(label): arr.flatten().tolist() for label, arr in zip(labels, arrays)} + ) + + my_metrics_copy = my_metrics.copy() + + # Add metric + m_record.set_metrics(my_metrics, keep_input=keep_input) + + # Check metrics are actually added + # Check that input dict has been emptied when enabled such behaviour + if keep_input: + assert my_metrics == m_record.data + else: + assert 
my_metrics_copy == m_record.data + assert len(my_metrics) == 0 + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: float(x.flatten()[0])), # str: float + (str, lambda x: bool(x.flatten()[0])), # str: bool + (str, lambda x: x.flatten().tobytes()), # str: bytes + (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + (str, lambda x: x.flatten().astype("bool").tolist()), # str: List[bool] + (str, lambda x: [x.flatten().tobytes()]), # str: List[bytes] + (str, lambda x: []), # str: empyt list + ], +) +def test_set_configs_to_configsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[[NDArray], ConfigsRecordValues], +) -> None: + """Test adding configs of various types to a ConfigsRecord.""" + labels = [1, 2.0] + arrays = get_ndarrays() + + my_configs = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + c_record = ConfigsRecord(my_configs) + + # check values are actually there + assert c_record.data == my_configs + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_configs_to_configsrecord_with_incorrect_types( + key_type: 
Type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding configs of various unsupported types to a ConfigsRecord.""" + m_record = ConfigsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + with pytest.raises(TypeError): + m_record.set_configs(my_metrics) # type: ignore diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py new file mode 100644 index 000000000000..c1e724fa2758 --- /dev/null +++ b/src/py/flwr/common/recordset_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import OrderedDict + +from .parametersrecord import Array, ParametersRecord +from .typing import Parameters + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool = False +) -> Parameters: + """Convert ParameterRecord to legacy Parameters. + + Warning: Because `Arrays` in `ParametersRecord` encode more information of the + array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it + might not be possible to reconstruct such data structures from `Parameters` objects + alone. 
Additional information or metadata must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be converted into Parameters. + keep_input : bool (default: False) + A boolean indicating whether entries in the record should be deleted from the + input dictionary immediately after adding them to the record. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record.data[key].data) + + if not keep_input: + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool = False +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. 
+ """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + num_arrays = len(parameters.tensors) + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + p_record.set_parameters( + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) + ) + + return p_record diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index a60fff57e7bf..5441e766983a 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -156,6 +156,7 @@ class RetryInvoker: >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ + # pylint: disable-next=too-many-arguments def __init__( self, wait_factory: Callable[[], Generator[float, None, None]], diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c8c73e87e04a..2094c76a9856 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,8 +15,20 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, cast - +from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast + +from google.protobuf.message import Message + +# pylint: disable=E0611 +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import BoolList, BytesList +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue +from flwr.proto.recordset_pb2 import DoubleList +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import Sint64List, StringList from flwr.proto.task_pb2 import Value from flwr.proto.transport_pb2 import ( 
ClientMessage, @@ -28,7 +40,11 @@ Status, ) +# pylint: enable=E0611 from . import typing +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord # === ServerMessage message === @@ -59,7 +75,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa server_message.evaluate_ins, ) ) - raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ServerMessage, cannot serialize to ProtoBuf" + ) def server_message_from_proto( @@ -91,7 +109,7 @@ def server_message_from_proto( server_message_proto.evaluate_ins, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) @@ -125,7 +143,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa client_message.evaluate_res, ) ) - raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ClientMessage, cannot serialize to ProtoBuf" + ) def client_message_from_proto( @@ -157,7 +177,7 @@ def client_message_from_proto( client_message_proto.evaluate_res, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" ) @@ -474,7 +494,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: if isinstance(scalar, str): return Scalar(string=scalar) - raise Exception( + raise ValueError( f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})" ) @@ -489,7 +509,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: # === Value messages === -_python_type_to_field_name = { +_type_to_field = { float: "double", int: "sint64", bool: "bool", @@ -498,27 +518,25 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: } -_python_list_type_to_message_and_field_name = { - float: (Value.DoubleList, "double_list"), - int: 
(Value.Sint64List, "sint64_list"), - bool: (Value.BoolList, "bool_list"), - str: (Value.StringList, "string_list"), - bytes: (Value.BytesList, "bytes_list"), +_list_type_to_class_and_field = { + float: (DoubleList, "double_list"), + int: (Sint64List, "sint64_list"), + bool: (BoolList, "bool_list"), + str: (StringList, "string_list"), + bytes: (BytesList, "bytes_list"), } def _check_value(value: typing.Value) -> None: - if isinstance(value, tuple(_python_type_to_field_name.keys())): + if isinstance(value, tuple(_type_to_field.keys())): return if isinstance(value, list): - if len(value) > 0 and isinstance( - value[0], tuple(_python_type_to_field_name.keys()) - ): + if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())): data_type = type(value[0]) for element in value: if isinstance(element, data_type): continue - raise Exception( + raise TypeError( f"Inconsistent type: the types of elements in the list must " f"be the same (expected {data_type}, but got {type(element)})." ) @@ -535,12 +553,12 @@ def value_to_proto(value: typing.Value) -> Value: arg = {} if isinstance(value, list): - msg_class, field_name = _python_list_type_to_message_and_field_name[ + msg_class, field_name = _list_type_to_class_and_field[ type(value[0]) if len(value) > 0 else int ] arg[field_name] = msg_class(vals=value) else: - arg[_python_type_to_field_name[type(value)]] = value + arg[_type_to_field[type(value)]] = value return Value(**arg) @@ -569,3 +587,135 @@ def named_values_from_proto( ) -> Dict[str, typing.Value]: """Deserialize named values from ProtoBuf.""" return {name: value_from_proto(value) for name, value in named_values_proto.items()} + + +# === Record messages === + + +T = TypeVar("T") + + +def _record_value_to_proto( + value: Any, allowed_types: List[type], proto_class: Type[T] +) -> T: + """Serialize `*RecordValue` to ProtoBuf.""" + arg = {} + for t in allowed_types: + # Single element + # Note: `isinstance(False, int) == True`. 
+ if type(value) == t: # pylint: disable=C0123 + arg[_type_to_field[t]] = value + return proto_class(**arg) + # List + if isinstance(value, list) and all(isinstance(item, t) for item in value): + list_class, field_name = _list_type_to_class_and_field[t] + arg[field_name] = list_class(vals=value) + return proto_class(**arg) + # Invalid types + raise TypeError( + f"The type of the following value is not allowed " + f"in '{proto_class.__name__}':\n{value}" + ) + + +def _record_value_from_proto(value_proto: Message) -> Any: + """Deserialize `*RecordValue` from ProtoBuf.""" + value_field = cast(str, value_proto.WhichOneof("value")) + if value_field.endswith("list"): + value = list(getattr(value_proto, value_field).vals) + else: + value = getattr(value_proto, value_field) + return value + + +def _record_value_dict_to_proto( + value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T] +) -> Dict[str, T]: + """Serialize the record value dict to ProtoBuf.""" + + def proto(_v: Any) -> T: + return _record_value_to_proto(_v, allowed_types, value_proto_class) + + return {k: proto(v) for k, v in value_dict.items()} + + +def _record_value_dict_from_proto( + value_dict_proto: MutableMapping[str, Any] +) -> Dict[str, Any]: + """Deserialize the record value dict from ProtoBuf.""" + return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()} + + +def array_to_proto(array: Array) -> ProtoArray: + """Serialize Array to ProtoBuf.""" + return ProtoArray(**vars(array)) + + +def array_from_proto(array_proto: ProtoArray) -> Array: + """Deserialize Array from ProtoBuf.""" + return Array( + dtype=array_proto.dtype, + shape=list(array_proto.shape), + stype=array_proto.stype, + data=array_proto.data, + ) + + +def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: + """Serialize ParametersRecord to ProtoBuf.""" + return ProtoParametersRecord( + data_keys=record.data.keys(), + data_values=map(array_to_proto, 
record.data.values()), + ) + + +def parameters_record_from_proto( + record_proto: ProtoParametersRecord, +) -> ParametersRecord: + """Deserialize ParametersRecord from ProtoBuf.""" + return ParametersRecord( + array_dict=OrderedDict( + zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values)) + ), + keep_input=False, + ) + + +def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: + """Serialize MetricsRecord to ProtoBuf.""" + return ProtoMetricsRecord( + data=_record_value_dict_to_proto( + record.data, [float, int], ProtoMetricsRecordValue + ) + ) + + +def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord: + """Deserialize MetricsRecord from ProtoBuf.""" + return MetricsRecord( + metrics_dict=cast( + Dict[str, typing.MetricsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) + + +def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: + """Serialize ConfigsRecord to ProtoBuf.""" + return ProtoConfigsRecord( + data=_record_value_dict_to_proto( + record.data, [int, float, bool, str, bytes], ProtoConfigsRecordValue + ) + ) + + +def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord: + """Deserialize ConfigsRecord from ProtoBuf.""" + return ConfigsRecord( + configs_dict=cast( + Dict[str, typing.ConfigsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index ba07890f4658..c584597d89f6 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -15,14 +15,31 @@ """(De-)serialization tests.""" -from typing import Dict, Union, cast +from typing import Dict, OrderedDict, Union, cast -from flwr.common import typing +# pylint: disable=E0611 from flwr.proto import transport_pb2 as pb2 - +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 
import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord + +# pylint: enable=E0611 +from . import typing +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord from .serde import ( + array_from_proto, + array_to_proto, + configs_record_from_proto, + configs_record_to_proto, + metrics_record_from_proto, + metrics_record_to_proto, named_values_from_proto, named_values_to_proto, + parameters_record_from_proto, + parameters_record_to_proto, scalar_from_proto, scalar_to_proto, status_from_proto, @@ -50,8 +67,8 @@ def test_serialisation_deserialisation() -> None: def test_status_to_proto() -> None: """Test status message (de-)serialization.""" # Prepare - code_msg = pb2.Code.OK - status_msg = pb2.Status(code=code_msg, message="Success") + code_msg = pb2.Code.OK # pylint: disable=E1101 + status_msg = pb2.Status(code=code_msg, message="Success") # pylint: disable=E1101 code = typing.Code.OK status = typing.Status(code=code, message="Success") @@ -66,8 +83,8 @@ def test_status_to_proto() -> None: def test_status_from_proto() -> None: """Test status message (de-)serialization.""" # Prepare - code_msg = pb2.Code.OK - status_msg = pb2.Status(code=code_msg, message="Success") + code_msg = pb2.Code.OK # pylint: disable=E1101 + status_msg = pb2.Status(code=code_msg, message="Success") # pylint: disable=E1101 code = typing.Code.OK status = typing.Status(code=code, message="Success") @@ -157,3 +174,71 @@ def test_named_values_serialization_deserialization() -> None: assert elm1 == elm2 else: assert expected == actual + + +def test_array_serialization_deserialization() -> None: + """Test serialization and deserialization of Array.""" + # Prepare + original = Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234") + + # Execute + proto = 
array_to_proto(original) + deserialized = array_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoArray) + assert original == deserialized + + +def test_parameters_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ParametersRecord.""" + # Prepare + original = ParametersRecord( + array_dict=OrderedDict( + [ + ("k1", Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")), + ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")), + ] + ), + keep_input=False, + ) + + # Execute + proto = parameters_record_to_proto(original) + deserialized = parameters_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoParametersRecord) + assert original.data == deserialized.data + + +def test_metrics_record_serialization_deserialization() -> None: + """Test serialization and deserialization of MetricsRecord.""" + # Prepare + original = MetricsRecord( + metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False + ) + + # Execute + proto = metrics_record_to_proto(original) + deserialized = metrics_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoMetricsRecord) + assert original.data == deserialized.data + + +def test_configs_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ConfigsRecord.""" + # Prepare + original = ConfigsRecord( + configs_dict={"learning_rate": 0.01, "batch_size": 32}, keep_input=False + ) + + # Execute + proto = configs_record_to_proto(original) + deserialized = configs_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoConfigsRecord) + assert original.data == deserialized.data diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index d56726d83378..fed8b5a978bc 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -152,6 +152,10 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A 
START_DRIVER_ENTER = auto() START_DRIVER_LEAVE = auto() + # SuperNode: flower-client + RUN_CLIENT_ENTER = auto() + RUN_CLIENT_LEAVE = auto() + # Use the ThreadPoolExecutor with max_workers=1 to have a queue # and also ensure that telemetry calls are not blocking. diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 6c0266f5eec8..d6b2ec9b158c 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -45,6 +45,15 @@ List[str], ] +# Value types for common.MetricsRecord +MetricsScalar = Union[int, float] +MetricsScalarList = Union[List[int], List[float]] +MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] +# Value types for common.ConfigsRecord +ConfigsScalar = Union[MetricsScalar, str, bytes, bool] +ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]] +ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList] + Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index 3cb8652365d8..4fa1ad8b5c02 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -25,7 +25,7 @@ from flwr.common import EventType, event from flwr.common.address import parse_address from flwr.common.logger import log -from flwr.proto import driver_pb2 +from flwr.proto import driver_pb2 # pylint: disable=E0611 from flwr.server.app import ServerConfig, init_defaults, run_fl from flwr.server.client_manager import ClientManager from flwr.server.history import History @@ -170,8 +170,10 @@ def update_client_manager( and dead nodes will be removed from the ClientManager via `client_manager.unregister()`. 
""" - # Request for workload_id - workload_id = driver.create_workload(driver_pb2.CreateWorkloadRequest()).workload_id + # Request for run_id + run_id = driver.create_run( + driver_pb2.CreateRunRequest() # pylint: disable=E1101 + ).run_id # Loop until the driver is disconnected registered_nodes: Dict[int, DriverClientProxy] = {} @@ -181,7 +183,7 @@ def update_client_manager( if driver.stub is None: break get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(workload_id=workload_id) + req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101 ) all_node_ids = {node.node_id for node in get_nodes_res.nodes} dead_nodes = set(registered_nodes).difference(all_node_ids) @@ -199,7 +201,7 @@ def update_client_manager( node_id=node_id, driver=driver, anonymous=False, - workload_id=workload_id, + run_id=run_id, ) if client_manager.register(client_proxy): registered_nodes[node_id] = client_proxy diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 91b4fd30bc4b..bfa0098f68e2 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
# ============================================================================== """Flower Driver app tests.""" -# pylint: disable=no-self-use import threading @@ -22,8 +21,11 @@ from unittest.mock import MagicMock from flwr.driver.app import update_client_manager -from flwr.proto.driver_pb2 import CreateWorkloadResponse, GetNodesResponse -from flwr.proto.node_pb2 import Node +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunResponse, + GetNodesResponse, +) +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.server.client_manager import SimpleClientManager @@ -43,7 +45,7 @@ def test_simple_client_manager_update(self) -> None: ] driver = MagicMock() driver.stub = "driver stub" - driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1) + driver.create_run.return_value = CreateRunResponse(run_id=1) driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) client_manager = SimpleClientManager() lock = threading.Lock() @@ -76,7 +78,7 @@ def test_simple_client_manager_update(self) -> None: driver.stub = None # Assert - driver.create_workload.assert_called_once() + driver.create_run.assert_called_once() assert node_ids == {node.node_id for node in expected_nodes} assert updated_node_ids == {node.node_id for node in expected_updated_nodes} diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index f1a7c6663c11..512a2001165e 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -18,14 +18,14 @@ from typing import Iterable, List, Optional, Tuple from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, GetNodesRequest, PullTaskResRequest, PushTaskInsRequest, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.node_pb2 import Node # pylint: 
disable=E0611 +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 class Driver: @@ -54,37 +54,37 @@ def __init__( self.addr = driver_service_address self.certificates = certificates self.grpc_driver: Optional[GrpcDriver] = None - self.workload_id: Optional[int] = None + self.run_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) - def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: + def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: # Check if the GrpcDriver is initialized - if self.grpc_driver is None or self.workload_id is None: - # Connect and create workload + if self.grpc_driver is None or self.run_id is None: + # Connect and create run self.grpc_driver = GrpcDriver( driver_service_address=self.addr, certificates=self.certificates ) self.grpc_driver.connect() - res = self.grpc_driver.create_workload(CreateWorkloadRequest()) - self.workload_id = res.workload_id + res = self.grpc_driver.create_run(CreateRunRequest()) + self.run_id = res.run_id - return self.grpc_driver, self.workload_id + return self.grpc_driver, self.run_id def get_nodes(self) -> List[Node]: """Get node IDs.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + grpc_driver, run_id = self._get_grpc_driver_and_run_id() # Call GrpcDriver method - res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id)) + res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) return list(res.nodes) def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: """Schedule tasks.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + grpc_driver, run_id = self._get_grpc_driver_and_run_id() - # Set workload_id + # Set run_id for task_ins in task_ins_list: - task_ins.workload_id = workload_id + task_ins.run_id = run_id # Call GrpcDriver method res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) @@ -92,7 +92,7 @@ def push_task_ins(self, task_ins_list: 
List[TaskIns]) -> List[str]: def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: """Get task results.""" - grpc_driver, _ = self._get_grpc_driver_and_workload_id() + grpc_driver, _ = self._get_grpc_driver_and_run_id() # Call GrpcDriver method res = grpc_driver.pull_task_res( diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 6d60fc49159b..8b2e51c17ea0 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -20,7 +20,12 @@ from flwr import common from flwr.common import serde -from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2 +from flwr.proto import ( # pylint: disable=E0611 + driver_pb2, + node_pb2, + task_pb2, + transport_pb2, +) from flwr.server.client_proxy import ClientProxy from .grpc_driver import GrpcDriver @@ -31,20 +36,18 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__( - self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int - ): + def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver - self.workload_id = workload_id + self.run_id = run_id self.anonymous = anonymous def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( + server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 serde.server_message_to_proto( server_message=common.ServerMessage(get_properties_ins=ins) ) @@ -58,7 +61,7 @@ def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( + server_message_proto: transport_pb2.ServerMessage = ( # 
pylint: disable=E1101 serde.server_message_to_proto( server_message=common.ServerMessage(get_parameters_ins=ins) ) @@ -70,7 +73,7 @@ def get_parameters( def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( + server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 serde.server_message_to_proto( server_message=common.ServerMessage(fit_ins=ins) ) @@ -84,7 +87,7 @@ def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( + server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 serde.server_message_to_proto( server_message=common.ServerMessage(evaluate_ins=ins) ) @@ -101,25 +104,29 @@ def reconnect( return common.DisconnectRes(reason="") # Nothing to do here (yet) def _send_receive_msg( - self, server_message: transport_pb2.ServerMessage, timeout: Optional[float] - ) -> transport_pb2.ClientMessage: - task_ins = task_pb2.TaskIns( + self, + server_message: transport_pb2.ServerMessage, # pylint: disable=E1101 + timeout: Optional[float], + ) -> transport_pb2.ClientMessage: # pylint: disable=E1101 + task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", group_id="", - workload_id=self.workload_id, - task=task_pb2.Task( - producer=node_pb2.Node( + run_id=self.run_id, + task=task_pb2.Task( # pylint: disable=E1101 + producer=node_pb2.Node( # pylint: disable=E1101 node_id=0, anonymous=True, ), - consumer=node_pb2.Node( + consumer=node_pb2.Node( # pylint: disable=E1101 node_id=self.node_id, anonymous=self.anonymous, ), legacy_server_message=server_message, ), ) - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=[task_ins]) + push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 + 
task_ins_list=[task_ins] + ) # Send TaskIns to Driver API push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req) @@ -135,15 +142,15 @@ def _send_receive_msg( start_time = time.time() while True: - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), + pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101 + node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101 task_ids=[task_id], ) # Ask Driver API for TaskRes pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req) - task_res_list: List[task_pb2.TaskRes] = list( + task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101 pull_task_res_res.task_res_list ) if len(task_res_list) == 1: diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index 82b5b46d7810..d3cab152e4db 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -23,8 +23,12 @@ import flwr from flwr.common.typing import Config, GetParametersIns from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 -from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Parameters, + Scalar, +) MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -37,34 +41,42 @@ class DriverClientProxyTestCase(unittest.TestCase): def setUp(self) -> None: """Set up mocks for tests.""" self.driver = MagicMock() - self.driver.get_nodes.return_value = driver_pb2.GetNodesResponse( - nodes=[node_pb2.Node(node_id=1, anonymous=False)] + self.driver.get_nodes.return_value = ( + driver_pb2.GetNodesResponse( # pylint: disable=E1101 + nodes=[ + node_pb2.Node(node_id=1, anonymous=False) # pylint: disable=E1101 + ] 
+ ) ) def test_get_properties(self) -> None: """Test positive case.""" # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + legacy_client_message=ClientMessage( + get_properties_res=ClientMessage.GetPropertiesRes( + properties=CLIENT_PROPERTIES + ) ) - ) - ), - ) - ] + ), + ) + ] + ) ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) request_properties: Config = {"tensor_type": "str"} ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( @@ -80,27 +92,31 @@ def test_get_properties(self) -> None: def test_get_parameters(self) -> None: """Test positive case.""" # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - 
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( - parameters=MESSAGE_PARAMETERS, + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + legacy_client_message=ClientMessage( + get_parameters_res=ClientMessage.GetParametersRes( + parameters=MESSAGE_PARAMETERS, + ) ) - ) - ), - ) - ] + ), + ) + ] + ) ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) get_parameters_ins = GetParametersIns(config={}) @@ -115,28 +131,32 @@ def test_get_parameters(self) -> None: def test_fit(self) -> None: """Test positive case.""" # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( - parameters=MESSAGE_PARAMETERS, - num_examples=10, + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + legacy_client_message=ClientMessage( + fit_res=ClientMessage.FitRes( + 
parameters=MESSAGE_PARAMETERS, + num_examples=10, + ) ) - ) - ), - ) - ] + ), + ) + ] + ) ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) @@ -152,27 +172,31 @@ def test_fit(self) -> None: def test_evaluate(self) -> None: """Test positive case.""" # Prepare - self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse( - task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + self.driver.push_task_ins.return_value = ( + driver_pb2.PushTaskInsResponse( # pylint: disable=E1101 + task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"] + ) ) - self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse( - task_res_list=[ - task_pb2.TaskRes( - task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", - group_id="", - workload_id=0, - task=task_pb2.Task( - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 + self.driver.pull_task_res.return_value = ( + driver_pb2.PullTaskResResponse( # pylint: disable=E1101 + task_res_list=[ + task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + legacy_client_message=ClientMessage( + evaluate_res=ClientMessage.EvaluateRes( + loss=0.0, num_examples=0 + ) ) - ) - ), - ) - ] + ), + ) + ] + ) ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) parameters = flwr.common.Parameters(tensors=[], tensor_type="np") evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index 820018788a8f..1854a92b5ebe 100644 --- 
a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -19,12 +19,12 @@ from unittest.mock import Mock, patch from flwr.driver.driver import Driver -from flwr.proto.driver_pb2 import ( +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, PullTaskResRequest, PushTaskInsRequest, ) -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 class TestDriver(unittest.TestCase): @@ -33,9 +33,9 @@ class TestDriver(unittest.TestCase): def setUp(self) -> None: """Initialize mock GrpcDriver and Driver instance before each test.""" mock_response = Mock() - mock_response.workload_id = 61016 + mock_response.run_id = 61016 self.mock_grpc_driver = Mock() - self.mock_grpc_driver.create_workload.return_value = mock_response + self.mock_grpc_driver.create_run.return_value = mock_response self.patcher = patch( "flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver ) @@ -47,27 +47,27 @@ def tearDown(self) -> None: self.patcher.stop() def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriver doesn't initialize if workload is created.""" + """Test that GrpcDriver doesn't initialize if run is created.""" # Prepare self.driver.grpc_driver = self.mock_grpc_driver - self.driver.workload_id = 61016 + self.driver.run_id = 61016 # Execute # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Assert self.mock_grpc_driver.connect.assert_not_called() def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriver initialization when workload is not created.""" + """Test GrpcDriver initialization when run is not created.""" # Execute # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Assert 
self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(self.driver.workload_id, 61016) + self.assertEqual(self.driver.run_id, 61016) def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -85,7 +85,7 @@ def test_get_nodes(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) - self.assertEqual(args[0].workload_id, 61016) + self.assertEqual(args[0].run_id, 61016) self.assertEqual(nodes, mock_response.nodes) def test_push_task_ins(self) -> None: @@ -107,7 +107,7 @@ def test_push_task_ins(self) -> None: self.assertIsInstance(args[0], PushTaskInsRequest) self.assertEqual(task_ids, mock_response.task_ids) for task_ins in args[0].task_ins_list: - self.assertEqual(task_ins.workload_id, 61016) + self.assertEqual(task_ins.run_id, 61016) def test_pull_task_res_with_given_task_ids(self) -> None: """Test pulling task results with specific task IDs.""" @@ -136,9 +136,10 @@ def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert @@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index 7dd0a0f501c5..23d449790092 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -23,9 +23,9 @@ from flwr.common import EventType, event from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - 
CreateWorkloadResponse, +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -33,7 +33,7 @@ PushTaskInsRequest, PushTaskInsResponse, ) -from flwr.proto.driver_pb2_grpc import DriverStub +from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -84,15 +84,15 @@ def disconnect(self) -> None: channel.close() log(INFO, "[Driver] Disconnected") - def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: - """Request for workload ID.""" + def create_run(self, req: CreateRunRequest) -> CreateRunResponse: + """Request for run ID.""" # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API - res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) + res: CreateRunResponse = self.stub.CreateRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: @@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -122,7 +122,7 @@ def pull_task_res(self, req: 
PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index c138507e03e9..fe9c33da0fa9 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/driver.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,94 +16,29 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 
\x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x12\n\x10\x43reateRunRequest\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc1\x02\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') - - -_CREATEWORKLOADREQUEST = DESCRIPTOR.message_types_by_name['CreateWorkloadRequest'] -_CREATEWORKLOADRESPONSE = DESCRIPTOR.message_types_by_name['CreateWorkloadResponse'] -_GETNODESREQUEST = 
DESCRIPTOR.message_types_by_name['GetNodesRequest'] -_GETNODESRESPONSE = DESCRIPTOR.message_types_by_name['GetNodesResponse'] -_PUSHTASKINSREQUEST = DESCRIPTOR.message_types_by_name['PushTaskInsRequest'] -_PUSHTASKINSRESPONSE = DESCRIPTOR.message_types_by_name['PushTaskInsResponse'] -_PULLTASKRESREQUEST = DESCRIPTOR.message_types_by_name['PullTaskResRequest'] -_PULLTASKRESRESPONSE = DESCRIPTOR.message_types_by_name['PullTaskResResponse'] -CreateWorkloadRequest = _reflection.GeneratedProtocolMessageType('CreateWorkloadRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadRequest) - }) -_sym_db.RegisterMessage(CreateWorkloadRequest) - -CreateWorkloadResponse = _reflection.GeneratedProtocolMessageType('CreateWorkloadResponse', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadResponse) - }) -_sym_db.RegisterMessage(CreateWorkloadResponse) - -GetNodesRequest = _reflection.GeneratedProtocolMessageType('GetNodesRequest', (_message.Message,), { - 'DESCRIPTOR' : _GETNODESREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.GetNodesRequest) - }) -_sym_db.RegisterMessage(GetNodesRequest) - -GetNodesResponse = _reflection.GeneratedProtocolMessageType('GetNodesResponse', (_message.Message,), { - 'DESCRIPTOR' : _GETNODESRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.GetNodesResponse) - }) -_sym_db.RegisterMessage(GetNodesResponse) - -PushTaskInsRequest = _reflection.GeneratedProtocolMessageType('PushTaskInsRequest', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKINSREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskInsRequest) - }) 
-_sym_db.RegisterMessage(PushTaskInsRequest) - -PushTaskInsResponse = _reflection.GeneratedProtocolMessageType('PushTaskInsResponse', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKINSRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskInsResponse) - }) -_sym_db.RegisterMessage(PushTaskInsResponse) - -PullTaskResRequest = _reflection.GeneratedProtocolMessageType('PullTaskResRequest', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKRESREQUEST, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskResRequest) - }) -_sym_db.RegisterMessage(PullTaskResRequest) - -PullTaskResResponse = _reflection.GeneratedProtocolMessageType('PullTaskResResponse', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKRESRESPONSE, - '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskResResponse) - }) -_sym_db.RegisterMessage(PullTaskResResponse) - -_DRIVER = DESCRIPTOR.services_by_name['Driver'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _CREATEWORKLOADREQUEST._serialized_start=85 - _CREATEWORKLOADREQUEST._serialized_end=108 - _CREATEWORKLOADRESPONSE._serialized_start=110 - _CREATEWORKLOADRESPONSE._serialized_end=155 - _GETNODESREQUEST._serialized_start=157 - _GETNODESREQUEST._serialized_end=195 - _GETNODESRESPONSE._serialized_start=197 - _GETNODESRESPONSE._serialized_end=248 - _PUSHTASKINSREQUEST._serialized_start=250 - _PUSHTASKINSREQUEST._serialized_end=314 - _PUSHTASKINSRESPONSE._serialized_start=316 - _PUSHTASKINSRESPONSE._serialized_end=355 - _PULLTASKRESREQUEST._serialized_start=357 - _PULLTASKRESREQUEST._serialized_end=427 - _PULLTASKRESRESPONSE._serialized_start=429 - _PULLTASKRESRESPONSE._serialized_end=494 - 
_DRIVER._serialized_start=497 - _DRIVER._serialized_end=833 + _globals['_CREATERUNREQUEST']._serialized_start=85 + _globals['_CREATERUNREQUEST']._serialized_end=103 + _globals['_CREATERUNRESPONSE']._serialized_start=105 + _globals['_CREATERUNRESPONSE']._serialized_end=140 + _globals['_GETNODESREQUEST']._serialized_start=142 + _globals['_GETNODESREQUEST']._serialized_end=175 + _globals['_GETNODESRESPONSE']._serialized_start=177 + _globals['_GETNODESRESPONSE']._serialized_end=228 + _globals['_PUSHTASKINSREQUEST']._serialized_start=230 + _globals['_PUSHTASKINSREQUEST']._serialized_end=294 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=296 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=335 + _globals['_PULLTASKRESREQUEST']._serialized_start=337 + _globals['_PULLTASKRESREQUEST']._serialized_end=407 + _globals['_PULLTASKRESRESPONSE']._serialized_start=409 + _globals['_PULLTASKRESRESPONSE']._serialized_end=474 + _globals['_DRIVER']._serialized_start=477 + _globals['_DRIVER']._serialized_end=798 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 8b940972cb6d..8dc254a55e8c 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -13,34 +13,34 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class CreateWorkloadRequest(google.protobuf.message.Message): - """CreateWorkload""" +class CreateRunRequest(google.protobuf.message.Message): + """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor def __init__(self, ) -> None: ... 
-global___CreateWorkloadRequest = CreateWorkloadRequest +global___CreateRunRequest = CreateRunRequest -class CreateWorkloadResponse(google.protobuf.message.Message): +class CreateRunResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... -global___CreateWorkloadResponse = CreateWorkloadResponse + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___CreateRunResponse = CreateRunResponse class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... global___GetNodesRequest = GetNodesRequest class GetNodesResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index ea33b843d945..ac6815023ebd 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -14,10 +14,10 @@ def __init__(self, channel): Args: channel: A grpc.Channel. 
""" - self.CreateWorkload = channel.unary_unary( - '/flwr.proto.Driver/CreateWorkload', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + self.CreateRun = channel.unary_unary( + '/flwr.proto.Driver/CreateRun', + request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, ) self.GetNodes = channel.unary_unary( '/flwr.proto.Driver/GetNodes', @@ -39,8 +39,8 @@ def __init__(self, channel): class DriverServicer(object): """Missing associated documentation comment in .proto file.""" - def CreateWorkload(self, request, context): - """Request workload_id + def CreateRun(self, request, context): + """Request run_id """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -70,10 +70,10 @@ def PullTaskRes(self, request, context): def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { - 'CreateWorkload': grpc.unary_unary_rpc_method_handler( - servicer.CreateWorkload, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.SerializeToString, + 'CreateRun': grpc.unary_unary_rpc_method_handler( + servicer.CreateRun, + request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, ), 'GetNodes': grpc.unary_unary_rpc_method_handler( servicer.GetNodes, @@ -101,7 +101,7 @@ class Driver(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def CreateWorkload(request, + def CreateRun(request, target, options=(), channel_credentials=None, @@ -111,9 +111,9 @@ def CreateWorkload(request, wait_for_ready=None, timeout=None, metadata=None): 
- return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateWorkload', - flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateRun', + flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 1b10d71e943d..43cf45f39b25 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -8,10 +8,10 @@ import grpc class DriverStub: def __init__(self, channel: grpc.Channel) -> None: ... - CreateWorkload: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateWorkloadRequest, - flwr.proto.driver_pb2.CreateWorkloadResponse] - """Request workload_id""" + CreateRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.CreateRunRequest, + flwr.proto.driver_pb2.CreateRunResponse] + """Request run_id""" GetNodes: grpc.UnaryUnaryMultiCallable[ flwr.proto.driver_pb2.GetNodesRequest, @@ -31,11 +31,11 @@ class DriverStub: class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod - def CreateWorkload(self, - request: flwr.proto.driver_pb2.CreateWorkloadRequest, + def CreateRun(self, + request: flwr.proto.driver_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateWorkloadResponse: - """Request workload_id""" + ) -> flwr.proto.driver_pb2.CreateRunResponse: + """Request run_id""" pass @abc.abstractmethod diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index e86a53e2139e..e8443c296f0c 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by 
the protocol buffer compiler. DO NOT EDIT! # source: flwr/proto/fleet.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -18,115 +18,33 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 
\x01(\x04\x32\xc9\x02\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') - - -_CREATENODEREQUEST = DESCRIPTOR.message_types_by_name['CreateNodeRequest'] -_CREATENODERESPONSE = DESCRIPTOR.message_types_by_name['CreateNodeResponse'] -_DELETENODEREQUEST = DESCRIPTOR.message_types_by_name['DeleteNodeRequest'] -_DELETENODERESPONSE = DESCRIPTOR.message_types_by_name['DeleteNodeResponse'] -_PULLTASKINSREQUEST = DESCRIPTOR.message_types_by_name['PullTaskInsRequest'] -_PULLTASKINSRESPONSE = DESCRIPTOR.message_types_by_name['PullTaskInsResponse'] -_PUSHTASKRESREQUEST = DESCRIPTOR.message_types_by_name['PushTaskResRequest'] -_PUSHTASKRESRESPONSE = DESCRIPTOR.message_types_by_name['PushTaskResResponse'] -_PUSHTASKRESRESPONSE_RESULTSENTRY = _PUSHTASKRESRESPONSE.nested_types_by_name['ResultsEntry'] -_RECONNECT = DESCRIPTOR.message_types_by_name['Reconnect'] -CreateNodeRequest = _reflection.GeneratedProtocolMessageType('CreateNodeRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATENODEREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateNodeRequest) - }) -_sym_db.RegisterMessage(CreateNodeRequest) - -CreateNodeResponse = _reflection.GeneratedProtocolMessageType('CreateNodeResponse', (_message.Message,), { - 'DESCRIPTOR' : _CREATENODERESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateNodeResponse) - }) -_sym_db.RegisterMessage(CreateNodeResponse) - -DeleteNodeRequest = _reflection.GeneratedProtocolMessageType('DeleteNodeRequest', (_message.Message,), { - 'DESCRIPTOR' : 
_DELETENODEREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.DeleteNodeRequest) - }) -_sym_db.RegisterMessage(DeleteNodeRequest) - -DeleteNodeResponse = _reflection.GeneratedProtocolMessageType('DeleteNodeResponse', (_message.Message,), { - 'DESCRIPTOR' : _DELETENODERESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.DeleteNodeResponse) - }) -_sym_db.RegisterMessage(DeleteNodeResponse) - -PullTaskInsRequest = _reflection.GeneratedProtocolMessageType('PullTaskInsRequest', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKINSREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskInsRequest) - }) -_sym_db.RegisterMessage(PullTaskInsRequest) - -PullTaskInsResponse = _reflection.GeneratedProtocolMessageType('PullTaskInsResponse', (_message.Message,), { - 'DESCRIPTOR' : _PULLTASKINSRESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PullTaskInsResponse) - }) -_sym_db.RegisterMessage(PullTaskInsResponse) - -PushTaskResRequest = _reflection.GeneratedProtocolMessageType('PushTaskResRequest', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKRESREQUEST, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResRequest) - }) -_sym_db.RegisterMessage(PushTaskResRequest) - -PushTaskResResponse = _reflection.GeneratedProtocolMessageType('PushTaskResResponse', (_message.Message,), { - - 'ResultsEntry' : _reflection.GeneratedProtocolMessageType('ResultsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PUSHTASKRESRESPONSE_RESULTSENTRY, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResResponse.ResultsEntry) - }) - , - 'DESCRIPTOR' : _PUSHTASKRESRESPONSE, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.PushTaskResResponse) - }) 
-_sym_db.RegisterMessage(PushTaskResResponse) -_sym_db.RegisterMessage(PushTaskResResponse.ResultsEntry) - -Reconnect = _reflection.GeneratedProtocolMessageType('Reconnect', (_message.Message,), { - 'DESCRIPTOR' : _RECONNECT, - '__module__' : 'flwr.proto.fleet_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Reconnect) - }) -_sym_db.RegisterMessage(Reconnect) - -_FLEET = DESCRIPTOR.services_by_name['Fleet'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.fleet_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _PUSHTASKRESRESPONSE_RESULTSENTRY._options = None - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_options = b'8\001' - _CREATENODEREQUEST._serialized_start=84 - _CREATENODEREQUEST._serialized_end=103 - _CREATENODERESPONSE._serialized_start=105 - _CREATENODERESPONSE._serialized_end=157 - _DELETENODEREQUEST._serialized_start=159 - _DELETENODEREQUEST._serialized_end=210 - _DELETENODERESPONSE._serialized_start=212 - _DELETENODERESPONSE._serialized_end=232 - _PULLTASKINSREQUEST._serialized_start=234 - _PULLTASKINSREQUEST._serialized_end=304 - _PULLTASKINSRESPONSE._serialized_start=306 - _PULLTASKINSRESPONSE._serialized_end=413 - _PUSHTASKRESREQUEST._serialized_start=415 - _PUSHTASKRESREQUEST._serialized_end=479 - _PUSHTASKRESRESPONSE._serialized_start=482 - _PUSHTASKRESRESPONSE._serialized_end=656 - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_start=610 - _PUSHTASKRESRESPONSE_RESULTSENTRY._serialized_end=656 - _RECONNECT._serialized_start=658 - _RECONNECT._serialized_end=688 - _FLEET._serialized_start=691 - _FLEET._serialized_end=1020 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' + _globals['_CREATENODEREQUEST']._serialized_start=84 + _globals['_CREATENODEREQUEST']._serialized_end=103 + 
_globals['_CREATENODERESPONSE']._serialized_start=105 + _globals['_CREATENODERESPONSE']._serialized_end=157 + _globals['_DELETENODEREQUEST']._serialized_start=159 + _globals['_DELETENODEREQUEST']._serialized_end=210 + _globals['_DELETENODERESPONSE']._serialized_start=212 + _globals['_DELETENODERESPONSE']._serialized_end=232 + _globals['_PULLTASKINSREQUEST']._serialized_start=234 + _globals['_PULLTASKINSREQUEST']._serialized_end=304 + _globals['_PULLTASKINSRESPONSE']._serialized_start=306 + _globals['_PULLTASKINSRESPONSE']._serialized_end=413 + _globals['_PUSHTASKRESREQUEST']._serialized_start=415 + _globals['_PUSHTASKRESREQUEST']._serialized_end=479 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=482 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=656 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=610 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=656 + _globals['_RECONNECT']._serialized_start=658 + _globals['_RECONNECT']._serialized_end=688 + _globals['_FLEET']._serialized_start=691 + _globals['_FLEET']._serialized_end=1020 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/node_pb2.py b/src/py/flwr/proto/node_pb2.py index 9d91900d8f53..b300f2c562c2 100644 --- a/src/py/flwr/proto/node_pb2.py +++ b/src/py/flwr/proto/node_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: flwr/proto/node.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,19 +16,11 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"*\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x11\n\tanonymous\x18\x02 \x01(\x08\x62\x06proto3') - - -_NODE = DESCRIPTOR.message_types_by_name['Node'] -Node = _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), { - 'DESCRIPTOR' : _NODE, - '__module__' : 'flwr.proto.node_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Node) - }) -_sym_db.RegisterMessage(Node) - +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.node_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _NODE._serialized_start=37 - _NODE._serialized_end=79 + _globals['_NODE']._serialized_start=37 + _globals['_NODE']._serialized_end=79 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py new file mode 100644 index 000000000000..4134511f1f53 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: flwr/proto/recordset.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 
\x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.recordset_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_METRICSRECORD_DATAENTRY']._options = None + _globals['_METRICSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_CONFIGSRECORD_DATAENTRY']._options = None + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_DOUBLELIST']._serialized_start=42 + _globals['_DOUBLELIST']._serialized_end=68 + _globals['_SINT64LIST']._serialized_start=70 + _globals['_SINT64LIST']._serialized_end=96 + _globals['_BOOLLIST']._serialized_start=98 + _globals['_BOOLLIST']._serialized_end=122 + _globals['_STRINGLIST']._serialized_start=124 + _globals['_STRINGLIST']._serialized_end=150 + _globals['_BYTESLIST']._serialized_start=152 + _globals['_BYTESLIST']._serialized_end=177 + _globals['_ARRAY']._serialized_start=179 + _globals['_ARRAY']._serialized_end=245 + _globals['_METRICSRECORDVALUE']._serialized_start=248 + _globals['_METRICSRECORDVALUE']._serialized_end=407 + _globals['_CONFIGSRECORDVALUE']._serialized_start=410 + _globals['_CONFIGSRECORDVALUE']._serialized_end=755 + _globals['_PARAMETERSRECORD']._serialized_start=757 + _globals['_PARAMETERSRECORD']._serialized_end=834 + 
_globals['_METRICSRECORD']._serialized_start=837 + _globals['_METRICSRECORD']._serialized_end=980 + _globals['_METRICSRECORD_DATAENTRY']._serialized_start=905 + _globals['_METRICSRECORD_DATAENTRY']._serialized_end=980 + _globals['_CONFIGSRECORD']._serialized_start=983 + _globals['_CONFIGSRECORD']._serialized_end=1126 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi new file mode 100644 index 000000000000..1e9556de9ce6 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -0,0 +1,240 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class DoubleList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList + +class Sint64List(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... 
+global___Sint64List = Sint64List + +class BoolList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BoolList = BoolList + +class StringList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[typing.Text]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___StringList = StringList + +class BytesList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BytesList = BytesList + +class Array(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DTYPE_FIELD_NUMBER: builtins.int + SHAPE_FIELD_NUMBER: builtins.int + STYPE_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + dtype: typing.Text + @property + def shape(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... 
+ stype: typing.Text + data: builtins.bytes + def __init__(self, + *, + dtype: typing.Text = ..., + shape: typing.Optional[typing.Iterable[builtins.int]] = ..., + stype: typing.Text = ..., + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data","dtype",b"dtype","shape",b"shape","stype",b"stype"]) -> None: ... +global___Array = Array + +class MetricsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","double_list","sint64_list"]]: ... 
+global___MetricsRecordValue = MetricsRecordValue + +class ConfigsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + BOOL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + BYTES_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + BOOL_LIST_FIELD_NUMBER: builtins.int + STRING_LIST_FIELD_NUMBER: builtins.int + BYTES_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + bool: builtins.bool + string: typing.Text + bytes: builtins.bytes + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + @property + def bool_list(self) -> global___BoolList: ... + @property + def string_list(self) -> global___StringList: ... + @property + def bytes_list(self) -> global___BytesList: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + bool: builtins.bool = ..., + string: typing.Text = ..., + bytes: builtins.bytes = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + bool_list: typing.Optional[global___BoolList] = ..., + string_list: typing.Optional[global___StringList] = ..., + bytes_list: typing.Optional[global___BytesList] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... +global___ConfigsRecordValue = ConfigsRecordValue + +class ParametersRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DATA_KEYS_FIELD_NUMBER: builtins.int + DATA_VALUES_FIELD_NUMBER: builtins.int + @property + def data_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + @property + def data_values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Array]: ... + def __init__(self, + *, + data_keys: typing.Optional[typing.Iterable[typing.Text]] = ..., + data_values: typing.Optional[typing.Iterable[global___Array]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data_keys",b"data_keys","data_values",b"data_values"]) -> None: ... +global___ParametersRecord = ParametersRecord + +class MetricsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___MetricsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___MetricsRecordValue] = ..., + ) -> None: ... 
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... +global___MetricsRecord = MetricsRecord + +class ConfigsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... 
+global___ConfigsRecord = ConfigsRecord diff --git a/src/py/flwr/proto/recordset_pb2_grpc.py b/src/py/flwr/proto/recordset_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/recordset_pb2_grpc.pyi b/src/py/flwr/proto/recordset_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 6d8cf8fd3656..963b07db94f8 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -1,148 +1,45 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: flwr/proto/task.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 +from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 
\x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') - - - -_TASK = DESCRIPTOR.message_types_by_name['Task'] -_TASKINS = DESCRIPTOR.message_types_by_name['TaskIns'] -_TASKRES = DESCRIPTOR.message_types_by_name['TaskRes'] -_VALUE = DESCRIPTOR.message_types_by_name['Value'] -_VALUE_DOUBLELIST = _VALUE.nested_types_by_name['DoubleList'] -_VALUE_SINT64LIST = _VALUE.nested_types_by_name['Sint64List'] -_VALUE_BOOLLIST = _VALUE.nested_types_by_name['BoolList'] -_VALUE_STRINGLIST = _VALUE.nested_types_by_name['StringList'] -_VALUE_BYTESLIST = _VALUE.nested_types_by_name['BytesList'] -_SECUREAGGREGATION = DESCRIPTOR.message_types_by_name['SecureAggregation'] -_SECUREAGGREGATION_NAMEDVALUESENTRY = _SECUREAGGREGATION.nested_types_by_name['NamedValuesEntry'] -Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { - 'DESCRIPTOR' : _TASK, - '__module__' : 'flwr.proto.task_pb2' - # 
@@protoc_insertion_point(class_scope:flwr.proto.Task) - }) -_sym_db.RegisterMessage(Task) - -TaskIns = _reflection.GeneratedProtocolMessageType('TaskIns', (_message.Message,), { - 'DESCRIPTOR' : _TASKINS, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.TaskIns) - }) -_sym_db.RegisterMessage(TaskIns) - -TaskRes = _reflection.GeneratedProtocolMessageType('TaskRes', (_message.Message,), { - 'DESCRIPTOR' : _TASKRES, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.TaskRes) - }) -_sym_db.RegisterMessage(TaskRes) - -Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { - - 'DoubleList' : _reflection.GeneratedProtocolMessageType('DoubleList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_DOUBLELIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.DoubleList) - }) - , - - 'Sint64List' : _reflection.GeneratedProtocolMessageType('Sint64List', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_SINT64LIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.Sint64List) - }) - , - - 'BoolList' : _reflection.GeneratedProtocolMessageType('BoolList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_BOOLLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.BoolList) - }) - , - - 'StringList' : _reflection.GeneratedProtocolMessageType('StringList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_STRINGLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.StringList) - }) - , - - 'BytesList' : _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { - 'DESCRIPTOR' : _VALUE_BYTESLIST, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Value.BytesList) - }) - , - 'DESCRIPTOR' : _VALUE, - '__module__' : 'flwr.proto.task_pb2' - # 
@@protoc_insertion_point(class_scope:flwr.proto.Value) - }) -_sym_db.RegisterMessage(Value) -_sym_db.RegisterMessage(Value.DoubleList) -_sym_db.RegisterMessage(Value.Sint64List) -_sym_db.RegisterMessage(Value.BoolList) -_sym_db.RegisterMessage(Value.StringList) -_sym_db.RegisterMessage(Value.BytesList) - -SecureAggregation = _reflection.GeneratedProtocolMessageType('SecureAggregation', (_message.Message,), { - - 'NamedValuesEntry' : _reflection.GeneratedProtocolMessageType('NamedValuesEntry', (_message.Message,), { - 'DESCRIPTOR' : _SECUREAGGREGATION_NAMEDVALUESENTRY, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.SecureAggregation.NamedValuesEntry) - }) - , - 'DESCRIPTOR' : _SECUREAGGREGATION, - '__module__' : 'flwr.proto.task_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.SecureAggregation) - }) -_sym_db.RegisterMessage(SecureAggregation) -_sym_db.RegisterMessage(SecureAggregation.NamedValuesEntry) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd1\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12)\n\x02sa\x18\x08 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 
\x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _TASK.fields_by_name['legacy_server_message']._options = None - _TASK.fields_by_name['legacy_server_message']._serialized_options = b'\030\001' - _TASK.fields_by_name['legacy_client_message']._options = None - _TASK.fields_by_name['legacy_client_message']._serialized_options = b'\030\001' - _SECUREAGGREGATION_NAMEDVALUESENTRY._options = None - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_options = b'8\001' - _TASK._serialized_start=89 - _TASK._serialized_end=407 - _TASKINS._serialized_start=409 - _TASKINS._serialized_end=506 - _TASKRES._serialized_start=508 - _TASKRES._serialized_end=605 - 
_VALUE._serialized_start=608 - _VALUE._serialized_end=1107 - _VALUE_DOUBLELIST._serialized_start=963 - _VALUE_DOUBLELIST._serialized_end=989 - _VALUE_SINT64LIST._serialized_start=991 - _VALUE_SINT64LIST._serialized_end=1017 - _VALUE_BOOLLIST._serialized_start=1019 - _VALUE_BOOLLIST._serialized_end=1043 - _VALUE_STRINGLIST._serialized_start=1045 - _VALUE_STRINGLIST._serialized_end=1071 - _VALUE_BYTESLIST._serialized_start=1073 - _VALUE_BYTESLIST._serialized_end=1098 - _SECUREAGGREGATION._serialized_start=1110 - _SECUREAGGREGATION._serialized_end=1270 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_start=1201 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_end=1270 + _globals['_TASK'].fields_by_name['legacy_server_message']._options = None + _globals['_TASK'].fields_by_name['legacy_server_message']._serialized_options = b'\030\001' + _globals['_TASK'].fields_by_name['legacy_client_message']._options = None + _globals['_TASK'].fields_by_name['legacy_client_message']._serialized_options = b'\030\001' + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._options = None + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_options = b'8\001' + _globals['_TASK']._serialized_start=117 + _globals['_TASK']._serialized_end=454 + _globals['_TASKINS']._serialized_start=456 + _globals['_TASKINS']._serialized_end=548 + _globals['_TASKRES']._serialized_start=550 + _globals['_TASKRES']._serialized_end=642 + _globals['_VALUE']._serialized_start=645 + _globals['_VALUE']._serialized_end=977 + _globals['_SECUREAGGREGATION']._serialized_start=980 + _globals['_SECUREAGGREGATION']._serialized_end=1140 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1071 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1140 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 7cf96cb61edf..ebe69d05c974 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -4,6 
+4,7 @@ isort:skip_file """ import builtins import flwr.proto.node_pb2 +import flwr.proto.recordset_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -21,6 +22,7 @@ class Task(google.protobuf.message.Message): DELIVERED_AT_FIELD_NUMBER: builtins.int TTL_FIELD_NUMBER: builtins.int ANCESTRY_FIELD_NUMBER: builtins.int + TASK_TYPE_FIELD_NUMBER: builtins.int SA_FIELD_NUMBER: builtins.int LEGACY_SERVER_MESSAGE_FIELD_NUMBER: builtins.int LEGACY_CLIENT_MESSAGE_FIELD_NUMBER: builtins.int @@ -33,6 +35,7 @@ class Task(google.protobuf.message.Message): ttl: typing.Text @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + task_type: typing.Text @property def sa(self) -> global___SecureAggregation: ... @property @@ -47,115 +50,61 @@ class Task(google.protobuf.message.Message): delivered_at: typing.Text = ..., ttl: typing.Text = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., + task_type: typing.Text = ..., sa: typing.Optional[global___SecureAggregation] = ..., legacy_server_message: typing.Optional[flwr.proto.transport_pb2.ServerMessage] = ..., legacy_client_message: typing.Optional[flwr.proto.transport_pb2.ClientMessage] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa","ttl",b"ttl"]) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ... global___Task = Task class TaskIns(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskIns = TaskIns class TaskRes(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... 
def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskRes = TaskRes class Value(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class DoubleList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.float]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class Sint64List(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.int]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BoolList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... 
- def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class StringList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BytesList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - DOUBLE_FIELD_NUMBER: builtins.int SINT64_FIELD_NUMBER: builtins.int BOOL_FIELD_NUMBER: builtins.int @@ -174,17 +123,17 @@ class Value(google.protobuf.message.Message): string: typing.Text bytes: builtins.bytes @property - def double_list(self) -> global___Value.DoubleList: + def double_list(self) -> flwr.proto.recordset_pb2.DoubleList: """List types""" pass @property - def sint64_list(self) -> global___Value.Sint64List: ... + def sint64_list(self) -> flwr.proto.recordset_pb2.Sint64List: ... @property - def bool_list(self) -> global___Value.BoolList: ... + def bool_list(self) -> flwr.proto.recordset_pb2.BoolList: ... @property - def string_list(self) -> global___Value.StringList: ... + def string_list(self) -> flwr.proto.recordset_pb2.StringList: ... @property - def bytes_list(self) -> global___Value.BytesList: ... + def bytes_list(self) -> flwr.proto.recordset_pb2.BytesList: ... 
def __init__(self, *, double: builtins.float = ..., @@ -192,11 +141,11 @@ class Value(google.protobuf.message.Message): bool: builtins.bool = ..., string: typing.Text = ..., bytes: builtins.bytes = ..., - double_list: typing.Optional[global___Value.DoubleList] = ..., - sint64_list: typing.Optional[global___Value.Sint64List] = ..., - bool_list: typing.Optional[global___Value.BoolList] = ..., - string_list: typing.Optional[global___Value.StringList] = ..., - bytes_list: typing.Optional[global___Value.BytesList] = ..., + double_list: typing.Optional[flwr.proto.recordset_pb2.DoubleList] = ..., + sint64_list: typing.Optional[flwr.proto.recordset_pb2.Sint64List] = ..., + bool_list: typing.Optional[flwr.proto.recordset_pb2.BoolList] = ..., + string_list: typing.Optional[flwr.proto.recordset_pb2.StringList] = ..., + bytes_list: typing.Optional[flwr.proto.recordset_pb2.BytesList] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... diff --git a/src/py/flwr/proto/transport_pb2.py b/src/py/flwr/proto/transport_pb2.py index 1e3785b0e312..d3aae72b63ab 100644 --- a/src/py/flwr/proto/transport_pb2.py +++ b/src/py/flwr/proto/transport_pb2.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: flwr/proto/transport.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,281 +16,73 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/transport.proto\x12\nflwr.proto\"9\n\x06Status\x12\x1e\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x10.flwr.proto.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\nParameters\x12\x0f\n\x07tensors\x18\x01 \x03(\x0c\x12\x13\n\x0btensor_type\x18\x02 \x01(\t\"\xba\x08\n\rServerMessage\x12?\n\rreconnect_ins\x18\x01 \x01(\x0b\x32&.flwr.proto.ServerMessage.ReconnectInsH\x00\x12H\n\x12get_properties_ins\x18\x02 \x01(\x0b\x32*.flwr.proto.ServerMessage.GetPropertiesInsH\x00\x12H\n\x12get_parameters_ins\x18\x03 \x01(\x0b\x32*.flwr.proto.ServerMessage.GetParametersInsH\x00\x12\x33\n\x07\x66it_ins\x18\x04 \x01(\x0b\x32 .flwr.proto.ServerMessage.FitInsH\x00\x12=\n\x0c\x65valuate_ins\x18\x05 \x01(\x0b\x32%.flwr.proto.ServerMessage.EvaluateInsH\x00\x1a\x1f\n\x0cReconnectIns\x12\x0f\n\x07seconds\x18\x01 \x01(\x03\x1a\x9d\x01\n\x10GetPropertiesIns\x12\x46\n\x06\x63onfig\x18\x01 \x03(\x0b\x32\x36.flwr.proto.ServerMessage.GetPropertiesIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x9d\x01\n\x10GetParametersIns\x12\x46\n\x06\x63onfig\x18\x01 \x03(\x0b\x32\x36.flwr.proto.ServerMessage.GetParametersIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 
\x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xb5\x01\n\x06\x46itIns\x12*\n\nparameters\x18\x01 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12<\n\x06\x63onfig\x18\x02 \x03(\x0b\x32,.flwr.proto.ServerMessage.FitIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xbf\x01\n\x0b\x45valuateIns\x12*\n\nparameters\x18\x01 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12\x41\n\x06\x63onfig\x18\x02 \x03(\x0b\x32\x31.flwr.proto.ServerMessage.EvaluateIns.ConfigEntry\x1a\x41\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x42\x05\n\x03msg\"\xa0\t\n\rClientMessage\x12\x41\n\x0e\x64isconnect_res\x18\x01 \x01(\x0b\x32\'.flwr.proto.ClientMessage.DisconnectResH\x00\x12H\n\x12get_properties_res\x18\x02 \x01(\x0b\x32*.flwr.proto.ClientMessage.GetPropertiesResH\x00\x12H\n\x12get_parameters_res\x18\x03 \x01(\x0b\x32*.flwr.proto.ClientMessage.GetParametersResH\x00\x12\x33\n\x07\x66it_res\x18\x04 \x01(\x0b\x32 .flwr.proto.ClientMessage.FitResH\x00\x12=\n\x0c\x65valuate_res\x18\x05 \x01(\x0b\x32%.flwr.proto.ClientMessage.EvaluateResH\x00\x1a\x33\n\rDisconnectRes\x12\"\n\x06reason\x18\x01 \x01(\x0e\x32\x12.flwr.proto.Reason\x1a\xcd\x01\n\x10GetPropertiesRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12N\n\nproperties\x18\x02 \x03(\x0b\x32:.flwr.proto.ClientMessage.GetPropertiesRes.PropertiesEntry\x1a\x45\n\x0fPropertiesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x62\n\x10GetParametersRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12*\n\nparameters\x18\x02 \x01(\x0b\x32\x16.flwr.proto.Parameters\x1a\xf2\x01\n\x06\x46itRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12*\n\nparameters\x18\x02 \x01(\x0b\x32\x16.flwr.proto.Parameters\x12\x14\n\x0cnum_examples\x18\x03 
\x01(\x03\x12>\n\x07metrics\x18\x04 \x03(\x0b\x32-.flwr.proto.ClientMessage.FitRes.MetricsEntry\x1a\x42\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\xde\x01\n\x0b\x45valuateRes\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.flwr.proto.Status\x12\x0c\n\x04loss\x18\x02 \x01(\x02\x12\x14\n\x0cnum_examples\x18\x03 \x01(\x03\x12\x43\n\x07metrics\x18\x04 \x03(\x0b\x32\x32.flwr.proto.ClientMessage.EvaluateRes.MetricsEntry\x1a\x42\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x42\x05\n\x03msg\"i\n\x06Scalar\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x08 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\r \x01(\x08H\x00\x12\x10\n\x06string\x18\x0e \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x0f \x01(\x0cH\x00\x42\x08\n\x06scalar*\x8d\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\"\n\x1eGET_PROPERTIES_NOT_IMPLEMENTED\x10\x01\x12\"\n\x1eGET_PARAMETERS_NOT_IMPLEMENTED\x10\x02\x12\x17\n\x13\x46IT_NOT_IMPLEMENTED\x10\x03\x12\x1c\n\x18\x45VALUATE_NOT_IMPLEMENTED\x10\x04*[\n\x06Reason\x12\x0b\n\x07UNKNOWN\x10\x00\x12\r\n\tRECONNECT\x10\x01\x12\x16\n\x12POWER_DISCONNECTED\x10\x02\x12\x14\n\x10WIFI_UNAVAILABLE\x10\x03\x12\x07\n\x03\x41\x43K\x10\x04\x32S\n\rFlowerService\x12\x42\n\x04Join\x12\x19.flwr.proto.ClientMessage\x1a\x19.flwr.proto.ServerMessage\"\x00(\x01\x30\x01\x62\x06proto3') -_CODE = DESCRIPTOR.enum_types_by_name['Code'] -Code = enum_type_wrapper.EnumTypeWrapper(_CODE) -_REASON = DESCRIPTOR.enum_types_by_name['Reason'] -Reason = enum_type_wrapper.EnumTypeWrapper(_REASON) -OK = 0 -GET_PROPERTIES_NOT_IMPLEMENTED = 1 -GET_PARAMETERS_NOT_IMPLEMENTED = 2 -FIT_NOT_IMPLEMENTED = 3 -EVALUATE_NOT_IMPLEMENTED = 4 -UNKNOWN = 0 -RECONNECT = 1 -POWER_DISCONNECTED = 2 -WIFI_UNAVAILABLE = 3 -ACK = 4 - - -_STATUS = DESCRIPTOR.message_types_by_name['Status'] -_PARAMETERS = DESCRIPTOR.message_types_by_name['Parameters'] 
-_SERVERMESSAGE = DESCRIPTOR.message_types_by_name['ServerMessage'] -_SERVERMESSAGE_RECONNECTINS = _SERVERMESSAGE.nested_types_by_name['ReconnectIns'] -_SERVERMESSAGE_GETPROPERTIESINS = _SERVERMESSAGE.nested_types_by_name['GetPropertiesIns'] -_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY = _SERVERMESSAGE_GETPROPERTIESINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_GETPARAMETERSINS = _SERVERMESSAGE.nested_types_by_name['GetParametersIns'] -_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY = _SERVERMESSAGE_GETPARAMETERSINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_FITINS = _SERVERMESSAGE.nested_types_by_name['FitIns'] -_SERVERMESSAGE_FITINS_CONFIGENTRY = _SERVERMESSAGE_FITINS.nested_types_by_name['ConfigEntry'] -_SERVERMESSAGE_EVALUATEINS = _SERVERMESSAGE.nested_types_by_name['EvaluateIns'] -_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY = _SERVERMESSAGE_EVALUATEINS.nested_types_by_name['ConfigEntry'] -_CLIENTMESSAGE = DESCRIPTOR.message_types_by_name['ClientMessage'] -_CLIENTMESSAGE_DISCONNECTRES = _CLIENTMESSAGE.nested_types_by_name['DisconnectRes'] -_CLIENTMESSAGE_GETPROPERTIESRES = _CLIENTMESSAGE.nested_types_by_name['GetPropertiesRes'] -_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY = _CLIENTMESSAGE_GETPROPERTIESRES.nested_types_by_name['PropertiesEntry'] -_CLIENTMESSAGE_GETPARAMETERSRES = _CLIENTMESSAGE.nested_types_by_name['GetParametersRes'] -_CLIENTMESSAGE_FITRES = _CLIENTMESSAGE.nested_types_by_name['FitRes'] -_CLIENTMESSAGE_FITRES_METRICSENTRY = _CLIENTMESSAGE_FITRES.nested_types_by_name['MetricsEntry'] -_CLIENTMESSAGE_EVALUATERES = _CLIENTMESSAGE.nested_types_by_name['EvaluateRes'] -_CLIENTMESSAGE_EVALUATERES_METRICSENTRY = _CLIENTMESSAGE_EVALUATERES.nested_types_by_name['MetricsEntry'] -_SCALAR = DESCRIPTOR.message_types_by_name['Scalar'] -Status = _reflection.GeneratedProtocolMessageType('Status', (_message.Message,), { - 'DESCRIPTOR' : _STATUS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Status) - 
}) -_sym_db.RegisterMessage(Status) - -Parameters = _reflection.GeneratedProtocolMessageType('Parameters', (_message.Message,), { - 'DESCRIPTOR' : _PARAMETERS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Parameters) - }) -_sym_db.RegisterMessage(Parameters) - -ServerMessage = _reflection.GeneratedProtocolMessageType('ServerMessage', (_message.Message,), { - - 'ReconnectIns' : _reflection.GeneratedProtocolMessageType('ReconnectIns', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_RECONNECTINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.ReconnectIns) - }) - , - - 'GetPropertiesIns' : _reflection.GeneratedProtocolMessageType('GetPropertiesIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetPropertiesIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_GETPROPERTIESINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetPropertiesIns) - }) - , - - 'GetParametersIns' : _reflection.GeneratedProtocolMessageType('GetParametersIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetParametersIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_GETPARAMETERSINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.GetParametersIns) - }) - , - - 'FitIns' : _reflection.GeneratedProtocolMessageType('FitIns', (_message.Message,), { - - 
'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_FITINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.FitIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_FITINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.FitIns) - }) - , - - 'EvaluateIns' : _reflection.GeneratedProtocolMessageType('EvaluateIns', (_message.Message,), { - - 'ConfigEntry' : _reflection.GeneratedProtocolMessageType('ConfigEntry', (_message.Message,), { - 'DESCRIPTOR' : _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.EvaluateIns.ConfigEntry) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE_EVALUATEINS, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage.EvaluateIns) - }) - , - 'DESCRIPTOR' : _SERVERMESSAGE, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ServerMessage) - }) -_sym_db.RegisterMessage(ServerMessage) -_sym_db.RegisterMessage(ServerMessage.ReconnectIns) -_sym_db.RegisterMessage(ServerMessage.GetPropertiesIns) -_sym_db.RegisterMessage(ServerMessage.GetPropertiesIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.GetParametersIns) -_sym_db.RegisterMessage(ServerMessage.GetParametersIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.FitIns) -_sym_db.RegisterMessage(ServerMessage.FitIns.ConfigEntry) -_sym_db.RegisterMessage(ServerMessage.EvaluateIns) -_sym_db.RegisterMessage(ServerMessage.EvaluateIns.ConfigEntry) - -ClientMessage = _reflection.GeneratedProtocolMessageType('ClientMessage', (_message.Message,), { - - 'DisconnectRes' : _reflection.GeneratedProtocolMessageType('DisconnectRes', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_DISCONNECTRES, - 
'__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.DisconnectRes) - }) - , - - 'GetPropertiesRes' : _reflection.GeneratedProtocolMessageType('GetPropertiesRes', (_message.Message,), { - - 'PropertiesEntry' : _reflection.GeneratedProtocolMessageType('PropertiesEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetPropertiesRes.PropertiesEntry) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPROPERTIESRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetPropertiesRes) - }) - , - - 'GetParametersRes' : _reflection.GeneratedProtocolMessageType('GetParametersRes', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_GETPARAMETERSRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.GetParametersRes) - }) - , - - 'FitRes' : _reflection.GeneratedProtocolMessageType('FitRes', (_message.Message,), { - - 'MetricsEntry' : _reflection.GeneratedProtocolMessageType('MetricsEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_FITRES_METRICSENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.FitRes.MetricsEntry) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE_FITRES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.FitRes) - }) - , - - 'EvaluateRes' : _reflection.GeneratedProtocolMessageType('EvaluateRes', (_message.Message,), { - - 'MetricsEntry' : _reflection.GeneratedProtocolMessageType('MetricsEntry', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTMESSAGE_EVALUATERES_METRICSENTRY, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.EvaluateRes.MetricsEntry) - }) 
- , - 'DESCRIPTOR' : _CLIENTMESSAGE_EVALUATERES, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage.EvaluateRes) - }) - , - 'DESCRIPTOR' : _CLIENTMESSAGE, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.ClientMessage) - }) -_sym_db.RegisterMessage(ClientMessage) -_sym_db.RegisterMessage(ClientMessage.DisconnectRes) -_sym_db.RegisterMessage(ClientMessage.GetPropertiesRes) -_sym_db.RegisterMessage(ClientMessage.GetPropertiesRes.PropertiesEntry) -_sym_db.RegisterMessage(ClientMessage.GetParametersRes) -_sym_db.RegisterMessage(ClientMessage.FitRes) -_sym_db.RegisterMessage(ClientMessage.FitRes.MetricsEntry) -_sym_db.RegisterMessage(ClientMessage.EvaluateRes) -_sym_db.RegisterMessage(ClientMessage.EvaluateRes.MetricsEntry) - -Scalar = _reflection.GeneratedProtocolMessageType('Scalar', (_message.Message,), { - 'DESCRIPTOR' : _SCALAR, - '__module__' : 'flwr.proto.transport_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.Scalar) - }) -_sym_db.RegisterMessage(Scalar) - -_FLOWERSERVICE = DESCRIPTOR.services_by_name['FlowerService'] +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.transport_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._options = None - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._options = None - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_FITINS_CONFIGENTRY._options = None - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_options = b'8\001' - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._options = None - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._options = None - 
_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_FITRES_METRICSENTRY._options = None - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_options = b'8\001' - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._options = None - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_options = b'8\001' - _CODE._serialized_start=2533 - _CODE._serialized_end=2674 - _REASON._serialized_start=2676 - _REASON._serialized_end=2767 - _STATUS._serialized_start=42 - _STATUS._serialized_end=99 - _PARAMETERS._serialized_start=101 - _PARAMETERS._serialized_end=151 - _SERVERMESSAGE._serialized_start=154 - _SERVERMESSAGE._serialized_end=1236 - _SERVERMESSAGE_RECONNECTINS._serialized_start=500 - _SERVERMESSAGE_RECONNECTINS._serialized_end=531 - _SERVERMESSAGE_GETPROPERTIESINS._serialized_start=534 - _SERVERMESSAGE_GETPROPERTIESINS._serialized_end=691 - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_GETPARAMETERSINS._serialized_start=694 - _SERVERMESSAGE_GETPARAMETERSINS._serialized_end=851 - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_FITINS._serialized_start=854 - _SERVERMESSAGE_FITINS._serialized_end=1035 - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_FITINS_CONFIGENTRY._serialized_end=691 - _SERVERMESSAGE_EVALUATEINS._serialized_start=1038 - _SERVERMESSAGE_EVALUATEINS._serialized_end=1229 - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_start=626 - _SERVERMESSAGE_EVALUATEINS_CONFIGENTRY._serialized_end=691 - _CLIENTMESSAGE._serialized_start=1239 - _CLIENTMESSAGE._serialized_end=2423 - _CLIENTMESSAGE_DISCONNECTRES._serialized_start=1587 - _CLIENTMESSAGE_DISCONNECTRES._serialized_end=1638 - _CLIENTMESSAGE_GETPROPERTIESRES._serialized_start=1641 - _CLIENTMESSAGE_GETPROPERTIESRES._serialized_end=1846 - 
_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_start=1777 - _CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY._serialized_end=1846 - _CLIENTMESSAGE_GETPARAMETERSRES._serialized_start=1848 - _CLIENTMESSAGE_GETPARAMETERSRES._serialized_end=1946 - _CLIENTMESSAGE_FITRES._serialized_start=1949 - _CLIENTMESSAGE_FITRES._serialized_end=2191 - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_start=2125 - _CLIENTMESSAGE_FITRES_METRICSENTRY._serialized_end=2191 - _CLIENTMESSAGE_EVALUATERES._serialized_start=2194 - _CLIENTMESSAGE_EVALUATERES._serialized_end=2416 - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_start=2125 - _CLIENTMESSAGE_EVALUATERES_METRICSENTRY._serialized_end=2191 - _SCALAR._serialized_start=2425 - _SCALAR._serialized_end=2530 - _FLOWERSERVICE._serialized_start=2769 - _FLOWERSERVICE._serialized_end=2852 + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._options = None + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._options = None + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._options = None + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_options = b'8\001' + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._options = None + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_options = b'8\001' + _globals['_CODE']._serialized_start=2533 + 
_globals['_CODE']._serialized_end=2674 + _globals['_REASON']._serialized_start=2676 + _globals['_REASON']._serialized_end=2767 + _globals['_STATUS']._serialized_start=42 + _globals['_STATUS']._serialized_end=99 + _globals['_PARAMETERS']._serialized_start=101 + _globals['_PARAMETERS']._serialized_end=151 + _globals['_SERVERMESSAGE']._serialized_start=154 + _globals['_SERVERMESSAGE']._serialized_end=1236 + _globals['_SERVERMESSAGE_RECONNECTINS']._serialized_start=500 + _globals['_SERVERMESSAGE_RECONNECTINS']._serialized_end=531 + _globals['_SERVERMESSAGE_GETPROPERTIESINS']._serialized_start=534 + _globals['_SERVERMESSAGE_GETPROPERTIESINS']._serialized_end=691 + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_GETPROPERTIESINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_GETPARAMETERSINS']._serialized_start=694 + _globals['_SERVERMESSAGE_GETPARAMETERSINS']._serialized_end=851 + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_GETPARAMETERSINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_FITINS']._serialized_start=854 + _globals['_SERVERMESSAGE_FITINS']._serialized_end=1035 + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_FITINS_CONFIGENTRY']._serialized_end=691 + _globals['_SERVERMESSAGE_EVALUATEINS']._serialized_start=1038 + _globals['_SERVERMESSAGE_EVALUATEINS']._serialized_end=1229 + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_start=626 + _globals['_SERVERMESSAGE_EVALUATEINS_CONFIGENTRY']._serialized_end=691 + _globals['_CLIENTMESSAGE']._serialized_start=1239 + _globals['_CLIENTMESSAGE']._serialized_end=2423 + _globals['_CLIENTMESSAGE_DISCONNECTRES']._serialized_start=1587 + _globals['_CLIENTMESSAGE_DISCONNECTRES']._serialized_end=1638 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES']._serialized_start=1641 + 
_globals['_CLIENTMESSAGE_GETPROPERTIESRES']._serialized_end=1846 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_start=1777 + _globals['_CLIENTMESSAGE_GETPROPERTIESRES_PROPERTIESENTRY']._serialized_end=1846 + _globals['_CLIENTMESSAGE_GETPARAMETERSRES']._serialized_start=1848 + _globals['_CLIENTMESSAGE_GETPARAMETERSRES']._serialized_end=1946 + _globals['_CLIENTMESSAGE_FITRES']._serialized_start=1949 + _globals['_CLIENTMESSAGE_FITRES']._serialized_end=2191 + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_start=2125 + _globals['_CLIENTMESSAGE_FITRES_METRICSENTRY']._serialized_end=2191 + _globals['_CLIENTMESSAGE_EVALUATERES']._serialized_start=2194 + _globals['_CLIENTMESSAGE_EVALUATERES']._serialized_end=2416 + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_start=2125 + _globals['_CLIENTMESSAGE_EVALUATERES_METRICSENTRY']._serialized_end=2191 + _globals['_SCALAR']._serialized_start=2425 + _globals['_SCALAR']._serialized_end=2530 + _globals['_FLOWERSERVICE']._serialized_start=2769 + _globals['_FLOWERSERVICE']._serialized_end=2852 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 63c24c37a685..636207e7a859 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -38,9 +38,15 @@ TRANSPORT_TYPE_REST, ) from flwr.common.logger import log -from flwr.proto.driver_pb2_grpc import add_DriverServicer_to_server -from flwr.proto.fleet_pb2_grpc import add_FleetServicer_to_server -from flwr.proto.transport_pb2_grpc import add_FlowerServiceServicer_to_server +from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 + add_DriverServicer_to_server, +) +from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 + add_FleetServicer_to_server, +) +from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611 + add_FlowerServiceServicer_to_server, +) from flwr.server.client_manager import ClientManager, SimpleClientManager from 
flwr.server.driver.driver_servicer import DriverServicer from flwr.server.fleet.grpc_bidi.driver_client_manager import DriverClientManager diff --git a/src/py/flwr/server/driver/driver_servicer.py b/src/py/flwr/server/driver/driver_servicer.py index f96b3b1262ac..275cc8ac6a03 100644 --- a/src/py/flwr/server/driver/driver_servicer.py +++ b/src/py/flwr/server/driver/driver_servicer.py @@ -22,10 +22,10 @@ import grpc from flwr.common.logger import log -from flwr.proto import driver_pb2_grpc -from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - CreateWorkloadResponse, +from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -33,8 +33,8 @@ PushTaskInsRequest, PushTaskInsResponse, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -51,20 +51,20 @@ def GetNodes( """Get available nodes.""" log(INFO, "DriverServicer.GetNodes") state: State = self.state_factory.state() - all_ids: Set[int] = state.get_nodes(request.workload_id) + all_ids: Set[int] = state.get_nodes(request.run_id) nodes: List[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids ] return GetNodesResponse(nodes=nodes) - def CreateWorkload( - self, request: CreateWorkloadRequest, context: grpc.ServicerContext - ) -> CreateWorkloadResponse: - """Create workload ID.""" - log(INFO, "DriverServicer.CreateWorkload") + def CreateRun( + self, request: CreateRunRequest, context: grpc.ServicerContext + ) -> CreateRunResponse: + """Create run ID.""" + log(INFO, "DriverServicer.CreateRun") state: State = self.state_factory.state() - workload_id = 
state.create_workload() - return CreateWorkloadResponse(workload_id=workload_id) + run_id = state.create_run() + return CreateRunResponse(run_id=run_id) def PushTaskIns( self, request: PushTaskInsRequest, context: grpc.ServicerContext diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py index 1f7a8e9259fc..6eccb056390a 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py @@ -24,8 +24,11 @@ import grpc from iterators import TimeoutIterator -from flwr.proto import transport_pb2_grpc -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto import transport_pb2_grpc # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_manager import ClientManager from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py index 64140ed274c9..b5c3f504af03 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py @@ -18,7 +18,10 @@ import unittest from unittest.mock import MagicMock, call -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.fleet.grpc_bidi.flower_service_servicer import ( FlowerServiceServicer, register_client_proxy, diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py index 6ae38ea3d805..d5b4a915c609 100644 --- 
a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py @@ -20,7 +20,10 @@ from threading import Condition from typing import Iterator, Optional -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) @dataclass @@ -113,7 +116,7 @@ def _transition(self, next_status: Status) -> None: ): self._status = next_status else: - raise Exception(f"Invalid transition: {self._status} to {next_status}") + raise ValueError(f"Invalid transition: {self._status} to {next_status}") self._cv.notify_all() @@ -129,7 +132,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._raise_if_closed() if self._status != Status.AWAITING_INS_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._ins_wrapper = ins_wrapper # Write self._transition(Status.INS_WRAPPER_AVAILABLE) @@ -146,7 +149,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._transition(Status.AWAITING_INS_WRAPPER) if res_wrapper is None: - raise Exception("ResWrapper can not be None") + raise ValueError("ResWrapper can not be None") return res_wrapper @@ -170,7 +173,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]: self._transition(Status.AWAITING_RES_WRAPPER) if ins_wrapper is None: - raise Exception("InsWrapper can not be None") + raise ValueError("InsWrapper can not be None") yield ins_wrapper @@ -180,7 +183,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None: self._raise_if_closed() if self._status != Status.AWAITING_RES_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._res_wrapper = res_wrapper # Write self._transition(Status.RES_WRAPPER_AVAILABLE) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py index 18a2144072ed..6527c45d7d6c 100644 
--- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py @@ -19,7 +19,10 @@ from threading import Thread from typing import List, Union -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.fleet.grpc_bidi.grpc_bridge import ( GrpcBridge, GrpcBridgeClosed, @@ -70,6 +73,7 @@ def test_workflow_successful() -> None: _ = next(ins_wrapper_iterator) bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage())) except Exception as exception: + # pylint: disable-next=broad-exception-raised raise Exception from exception # Wait until worker_thread is finished diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py index b9bc7330db31..46185896561e 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy.py @@ -19,7 +19,10 @@ from flwr import common from flwr.common import serde -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_proxy import ClientProxy from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py index 329f29b3f616..1a417ae433d5 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_client_proxy_test.py @@ -22,7 +22,11 @@ import flwr from flwr.common.typing import Config, GetParametersIns -from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + Parameters, + 
Scalar, +) from flwr.server.fleet.grpc_bidi.grpc_bridge import ResWrapper from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py index fc81e8eb8f4c..e05df88dcd12 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_server.py @@ -24,7 +24,9 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.logger import log -from flwr.proto.transport_pb2_grpc import add_FlowerServiceServicer_to_server +from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611 + add_FlowerServiceServicer_to_server, +) from flwr.server.client_manager import ClientManager from flwr.server.driver.driver_servicer import DriverServicer from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py index 1c737d31c7fc..5843934b64a4 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py @@ -23,9 +23,12 @@ from flwr.client.message_handler.task_handler import configure_task_res from flwr.common import EvaluateRes, FitRes, GetParametersRes, GetPropertiesRes, serde from flwr.common.logger import log -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.client_proxy import ClientProxy from flwr.server.state import State, StateFactory @@ -166,6 +169,6 @@ def _call_client_proxy( evaluate_res_proto = 
serde.evaluate_res_to_proto(res=evaluate_res) return ClientMessage(evaluate_res=evaluate_res_proto) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) diff --git a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py index 022470cffe8a..b12f365e898c 100644 --- a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py @@ -20,8 +20,8 @@ import grpc from flwr.common.logger import log -from flwr.proto import fleet_pb2_grpc -from flwr.proto.fleet_pb2 import ( +from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611 +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, diff --git a/src/py/flwr/server/fleet/message_handler/message_handler.py b/src/py/flwr/server/fleet/message_handler/message_handler.py index 71876386f059..8d451c896ed9 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler.py @@ -18,7 +18,7 @@ from typing import List, Optional from uuid import UUID -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -29,8 +29,8 @@ PushTaskResResponse, Reconnect, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.state import State diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/fleet/message_handler/message_handler_test.py index 25fd822492f2..c135f6fb7b61 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler_test.py @@ 
-17,14 +17,14 @@ from unittest.mock import MagicMock -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskRes +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from .message_handler import create_node, delete_node, pull_task_ins, push_task_res @@ -109,7 +109,7 @@ def test_push_task_res() -> None: TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task(), ), ], diff --git a/src/py/flwr/server/fleet/rest_rere/rest_api.py b/src/py/flwr/server/fleet/rest_rere/rest_api.py index cd1e47f24f00..b815558cb099 100644 --- a/src/py/flwr/server/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/fleet/rest_rere/rest_api.py @@ -18,7 +18,7 @@ import sys from flwr.common.constant import MISSING_EXTRA_REST -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 63ec1021ff5c..9b5c03aeeaf9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -47,14 +47,14 @@ class SuccessClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise 
NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" @@ -87,26 +87,26 @@ class FailingClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def test_fit_clients() -> None: diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/state/in_memory_state.py index 384839b7461f..f21845fcb909 100644 --- a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/state/in_memory_state.py @@ -22,7 +22,7 @@ from uuid import UUID, uuid4 from flwr.common import log, now -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, 
TaskRes # pylint: disable=E0611 from flwr.server.state.state import State from flwr.server.utils import validate_task_ins_or_res @@ -32,7 +32,7 @@ class InMemoryState(State): def __init__(self) -> None: self.node_ids: Set[int] = set() - self.workload_ids: Set[int] = set() + self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} @@ -43,9 +43,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: if any(errors): log(ERROR, errors) return None - # Validate workload_id - if task_ins.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_ins.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -104,9 +104,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Validate workload_id - if task_res.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_res.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -199,25 +199,25 @@ def delete_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} not found") self.node_ids.remove(node_id) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Return all available client nodes. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. 
""" - if workload_id not in self.workload_ids: + if run_id not in self.run_ids: return set() return self.node_ids - def create_workload(self) -> int: - """Create one workload.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - if workload_id not in self.workload_ids: - self.workload_ids.add(workload_id) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + if run_id not in self.run_ids: + self.run_ids.add(run_id) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index f3ff60f370e9..538ecb84491f 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -24,9 +24,12 @@ from uuid import UUID, uuid4 from flwr.common import log, now -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.utils.validator import validate_task_ins_or_res from .state import State @@ -37,9 +40,9 @@ ); """ -SQL_CREATE_TABLE_WORKLOAD = """ -CREATE TABLE IF NOT EXISTS workload( - workload_id INTEGER UNIQUE +SQL_CREATE_TABLE_RUN = """ +CREATE TABLE IF NOT EXISTS run( + run_id INTEGER UNIQUE ); """ @@ -47,7 +50,7 @@ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -58,7 +61,7 @@ 
ancestry TEXT, legacy_server_message BLOB, legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -67,7 +70,7 @@ CREATE TABLE IF NOT EXISTS task_res( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -78,7 +81,7 @@ ancestry TEXT, legacy_server_message BLOB, legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -119,7 +122,7 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur = self.conn.cursor() # Create each table if not exists queries - cur.execute(SQL_CREATE_TABLE_WORKLOAD) + cur.execute(SQL_CREATE_TABLE_RUN) cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) @@ -134,7 +137,7 @@ def query( ) -> List[Dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: - raise Exception("State is not initialized.") + raise AttributeError("State is not initialized.") if data is None: data = [] @@ -198,12 +201,12 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. + # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -333,12 +336,12 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. 
+ # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -459,7 +462,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """ if self.conn is None: - raise Exception("State not intitialized") + raise AttributeError("State not intitialized") with self.conn: self.conn.execute(query_1, data) @@ -485,17 +488,17 @@ def delete_node(self, node_id: int) -> None: query = "DELETE FROM node WHERE node_id = :node_id;" self.query(query, {"node_id": node_id}) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. 
""" - # Validate workload ID - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: + # Validate run ID + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: return set() # Get nodes @@ -504,19 +507,19 @@ def get_nodes(self, workload_id: int) -> Set[int]: result: Set[int] = {row["node_id"] for row in rows} return result - def create_workload(self) -> int: - """Create one workload and store it in state.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run and store it in state.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) # Check conflicts - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - # If workload_id does not exist - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO workload VALUES(:workload_id);" - self.query(query, {"workload_id": workload_id}) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + # If run_id does not exist + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: + query = "INSERT INTO run VALUES(:run_id);" + self.query(query, {"run_id": run_id}) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 @@ -537,7 +540,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -559,7 +562,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": 
task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -584,7 +587,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: result = TaskIns( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], @@ -612,7 +615,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: result = TaskRes( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py index da8fead1438e..a3f899386011 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/state/sqlite_state_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== """Test for utility functions.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import unittest @@ -27,11 +27,11 @@ class SqliteStateTest(unittest.TestCase): def test_ins_res_to_dict(self) -> None: """Check if all required keys are included in return value.""" # Prepare - ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id=0) + ins_res = create_task_ins(consumer_node_id=1, anonymous=True, run_id=0) expected_keys = [ "task_id", "group_id", - "workload_id", + "run_id", "producer_anonymous", "producer_node_id", "consumer_anonymous", diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/state/state.py index fd8bbc8e8e25..9337ae6d8624 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/state/state.py @@ -19,7 +19,7 @@ from typing import List, Optional, Set from uuid import UUID -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 class State(abc.ABC): @@ -43,7 +43,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: If `task_ins.task.consumer.anonymous` is `False`, then `task_ins.task.consumer.node_id` MUST be set (not 0) - If `task_ins.workload_id` is invalid, then + If `task_ins.run_id` is invalid, then storing the `task_ins` MUST fail. """ @@ -92,7 +92,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: If `task_res.task.consumer.anonymous` is `False`, then `task_res.task.consumer.node_id` MUST be set (not 0) - If `task_res.workload_id` is invalid, then + If `task_res.run_id` is invalid, then storing the `task_res` MUST fail. 
""" @@ -140,15 +140,15 @@ def delete_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ @abc.abstractmethod - def create_workload(self) -> int: - """Create one workload.""" + def create_run(self) -> int: + """Create one run.""" diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 59299451c3d8..7f9094625765 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import tempfile import unittest @@ -22,9 +22,12 @@ from typing import List from uuid import uuid4 -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from flwr.server.state import InMemoryState, SqliteState, State @@ -66,9 +69,9 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, 
workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) assert task_ins.task.created_at == "" # pylint: disable=no-member @@ -108,15 +111,15 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_0 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_1 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_2 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) # Insert three TaskIns @@ -136,7 +139,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_0)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_0) @@ -147,7 +150,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_1)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_1) @@ -182,10 +185,8 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -199,10 +200,8 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = 
self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -215,10 +214,8 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -231,10 +228,8 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -250,10 +245,8 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -278,13 +271,11 @@ def test_get_task_ins_limit_throws_for_limit_zero(self) -> None: with self.assertRaises(AssertionError): state.get_task_ins(node_id=1, limit=0) - def test_task_ins_store_invalid_workload_id_and_fail(self) 
-> None: - """Store TaskIns with invalid workload_id and fail.""" + def test_task_ins_store_invalid_run_id_and_fail(self) -> None: + """Store TaskIns with invalid run_id and fail.""" # Prepare state: State = self.state_factory() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=61016 - ) + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=61016) # Execute task_id = state.store_task_ins(task_ins) @@ -297,13 +288,13 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, anonymous=True, ancestry=[str(task_ins_id)], - workload_id=workload_id, + run_id=run_id, ) # Execute @@ -318,10 +309,10 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() # Execute - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -330,13 +321,13 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_ids = [] # Execute for _ in range(10): node_ids.append(state.create_node()) - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert for i in retrieved_node_ids: @@ -346,26 +337,26 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_id = state.create_node() # Execute state.delete_node(node_id) - 
retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 - def test_get_nodes_invalid_workload_id(self) -> None: - """Test retrieving all node_ids with invalid workload_id.""" + def test_get_nodes_invalid_run_id(self) -> None: + """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_workload() - invalid_workload_id = 61016 + state.create_run() + invalid_run_id = 61016 state.create_node() # Execute - retrieved_node_ids = state.get_nodes(invalid_workload_id) + retrieved_node_ids = state.get_nodes(invalid_run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -374,13 +365,9 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_0 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) - task_1 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Store two tasks state.store_task_ins(task_0) @@ -396,12 +383,12 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_0 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) task_1 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) # Store two tasks @@ -418,7 +405,7 @@ def 
test_num_task_res(self) -> None: def create_task_ins( consumer_node_id: int, anonymous: bool, - workload_id: int, + run_id: int, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -429,7 +416,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -446,13 +433,13 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - workload_id: int, + run_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( task_id="", group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 0772aa1ff13a..1750a7522379 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -29,6 +29,7 @@ from .fedprox import FedProx as FedProx from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg from .fedxgb_bagging import FedXgbBagging as FedXgbBagging +from .fedxgb_cyclic import FedXgbCyclic as FedXgbCyclic from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg from .fedyogi import FedYogi as FedYogi from .krum import Krum as Krum @@ -42,6 +43,7 @@ "FedAvg", "FedXgbNnAvg", "FedXgbBagging", + "FedXgbCyclic", "FedAvgAndroid", "FedAvgM", "FedOpt", diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 63926f2eaa51..c668b55eebe6 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -20,13 +20,14 @@ import numpy as np -from flwr.common import NDArray, NDArrays +from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays +from flwr.server.client_proxy import ClientProxy def aggregate(results: List[Tuple[NDArrays, int]]) -> 
NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training - num_examples_total = sum([num_examples for _, num_examples in results]) + num_examples_total = sum(num_examples for (_, num_examples) in results) # Create a list of weights, each multiplied by the related number of examples weighted_weights = [ @@ -41,6 +42,31 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime +def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: + """Compute in-place weighted average.""" + # Count total examples + num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) + + # Compute scaling factors for each result + scaling_factors = [ + fit_res.num_examples / num_examples_total for _, fit_res in results + ] + + # Let's do in-place aggregation + # Get first result, then add up each other + params = [ + scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + ] + for i, (_, fit_res) in enumerate(results[1:]): + res = ( + scaling_factors[i + 1] * x + for x in parameters_to_ndarrays(fit_res.parameters) + ) + params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + + return params + + def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples @@ -69,9 +95,9 @@ def aggregate_krum( # For each client, take the n-f-2 closest parameters vectors num_closest = max(1, len(weights) - num_malicious - 2) closest_indices = [] - for i, _ in enumerate(distance_matrix): + for distance in distance_matrix: closest_indices.append( - np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203 + np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203 ) # Compute the score for each client, that is the sum of the distances @@ -176,7 +202,7 @@ def aggregate_bulyan( def weighted_loss_avg(results: List[Tuple[int, float]]) -> 
float: """Aggregate evaluation results obtained from multiple clients.""" - num_total_evaluation_examples = sum([num_examples for num_examples, _ in results]) + num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] return sum(weighted_losses) / num_total_evaluation_examples @@ -207,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray: """ flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights]) distance_matrix = np.zeros((len(weights), len(weights))) - for i, _ in enumerate(flat_w): - for j, _ in enumerate(flat_w): - delta = flat_w[i] - flat_w[j] + for i, flat_w_i in enumerate(flat_w): + for j, flat_w_j in enumerate(flat_w): + delta = flat_w_i - flat_w_j norm = np.linalg.norm(delta) distance_matrix[i, j] = norm**2 return distance_matrix diff --git a/src/py/flwr/server/strategy/bulyan.py b/src/py/flwr/server/strategy/bulyan.py index 0243f4e6546f..1e4f97530ab7 100644 --- a/src/py/flwr/server/strategy/bulyan.py +++ b/src/py/flwr/server/strategy/bulyan.py @@ -38,10 +38,43 @@ # flake8: noqa: E501 +# pylint: disable=line-too-long class Bulyan(FedAvg): - """Bulyan strategy implementation.""" - - # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long, too-many-locals + """Bulyan strategy. + + Implementation based on https://arxiv.org/abs/1802.07927. + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. 
+ num_malicious_clients : int, optional + Number of malicious clients in the system. Defaults to 0. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + first_aggregation_rule: Callable + Byzantine resilient aggregation rule that is used as the first step of the Bulyan (e.g., Krum) + **aggregation_rule_kwargs: Any + arguments to the first_aggregation rule + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals def __init__( self, *, @@ -66,39 +99,6 @@ def __init__( first_aggregation_rule: Callable = aggregate_krum, # type: ignore **aggregation_rule_kwargs: Any, ) -> None: - """Bulyan strategy. - - Implementation based on https://arxiv.org/abs/1802.07927. - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - num_malicious_clients : int, optional - Number of malicious clients in the system. Defaults to 0. 
- evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - first_aggregation_rule: Callable - Byzantine resilient aggregation rule that is used as the first step of the Bulyan (e.g., Krum) - **aggregation_rule_kwargs: Any - arguments to the first_aggregation rule - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 3269735e9d73..8b3278cc9ba0 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: norm_bit_set_count = 0 for client_proxy, fit_res in results: if "dpfedavg_norm_bit" not in fit_res.metrics: - raise Exception( + raise KeyError( f"Indicator bit not returned by client with id {client_proxy.cid}." 
) if fit_res.metrics["dpfedavg_norm_bit"]: diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index 0154cfd79fc5..f2f1c206f3de 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -46,11 +46,11 @@ def __init__( self.num_sampled_clients = num_sampled_clients if clip_norm <= 0: - raise Exception("The clipping threshold should be a positive value.") + raise ValueError("The clipping threshold should be a positive value.") self.clip_norm = clip_norm if noise_multiplier < 0: - raise Exception("The noise multiplier should be a non-negative value.") + raise ValueError("The noise multiplier should be a non-negative value.") self.noise_multiplier = noise_multiplier self.server_side_noising = server_side_noising diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index 085362891d94..4a8f52d98e18 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -38,13 +38,47 @@ from .fedopt import FedOpt +# pylint: disable=line-too-long class FedAdagrad(FedOpt): """FedAdagrad strategy - Adaptive Federated Optimization using Adagrad. - Paper: https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. 
+ evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + eta : float, optional + Server-side learning rate. Defaults to 1e-1. + eta_l : float, optional + Client-side learning rate. Defaults to 1e-1. + tau : float, optional + Controls the algorithm's degree of adaptability. Defaults to 1e-9. """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-instance-attributes, line-too-long + # pylint: disable=too-many-arguments,too-many-locals,too-many-instance-attributes def __init__( self, *, @@ -69,43 +103,6 @@ def __init__( eta_l: float = 1e-1, tau: float = 1e-9, ) -> None: - """Federated learning strategy using Adagrad on server-side. - - Implementation based on https://arxiv.org/abs/2003.00295v5 - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. 
Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters - Initial global model parameters. - eta : float, optional - Server-side learning rate. Defaults to 1e-1. - eta_l : float, optional - Client-side learning rate. Defaults to 1e-1. - tau : float, optional - Controls the algorithm's degree of adaptability. Defaults to 1e-9. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedadam.py b/src/py/flwr/server/strategy/fedadam.py index ca6229029376..8a47cf0dd8ac 100644 --- a/src/py/flwr/server/strategy/fedadam.py +++ b/src/py/flwr/server/strategy/fedadam.py @@ -38,13 +38,51 @@ from .fedopt import FedOpt +# pylint: disable=line-too-long class FedAdam(FedOpt): """FedAdam - Adaptive Federated Optimization using Adam. - Paper: https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. 
+ min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + eta : float, optional + Server-side learning rate. Defaults to 1e-1. + eta_l : float, optional + Client-side learning rate. Defaults to 1e-1. + beta_1 : float, optional + Momentum parameter. Defaults to 0.9. + beta_2 : float, optional + Second moment parameter. Defaults to 0.99. + tau : float, optional + Controls the algorithm's degree of adaptability. Defaults to 1e-9. """ - # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals, line-too-long + # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals def __init__( self, *, @@ -71,47 +109,6 @@ def __init__( beta_2: float = 0.99, tau: float = 1e-9, ) -> None: - """Federated learning strategy using Adagrad on server-side. - - Implementation based on https://arxiv.org/abs/2003.00295v5 - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 1.0. 
- fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters - Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - eta : float, optional - Server-side learning rate. Defaults to 1e-1. - eta_l : float, optional - Client-side learning rate. Defaults to 1e-1. - beta_1 : float, optional - Momentum parameter. Defaults to 0.9. - beta_2 : float, optional - Second moment parameter. Defaults to 0.99. - tau : float, optional - Controls the algorithm's degree of adaptability. Defaults to 1e-9. 
- """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index 86afd31e2bfb..e4b126823fb6 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -37,7 +37,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy -from .aggregate import aggregate, weighted_loss_avg +from .aggregate import aggregate, aggregate_inplace, weighted_loss_avg from .strategy import Strategy WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ @@ -48,8 +48,43 @@ """ +# pylint: disable=line-too-long class FedAvg(Strategy): - """Configurable FedAvg strategy implementation.""" + """Federated Averaging strategy. + + Implementation based on https://arxiv.org/abs/1602.05629 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. In case `min_fit_clients` + is larger than `fraction_fit * available_clients`, `min_fit_clients` + will still be sampled. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. In case `min_evaluate_clients` + is larger than `fraction_evaluate * available_clients`, + `min_evaluate_clients` will still be sampled. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. 
+ on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long def __init__( @@ -72,42 +107,8 @@ def __init__( initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + inplace: bool = True, ) -> None: - """Federated Averaging strategy. - - Implementation based on https://arxiv.org/abs/1602.05629 - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. In case `min_fit_clients` - is larger than `fraction_fit * available_clients`, `min_fit_clients` - will still be sampled. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. In case `min_evaluate_clients` - is larger than `fraction_evaluate * available_clients`, - `min_evaluate_clients` will still be sampled. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. 
- on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - """ super().__init__() if ( @@ -128,6 +129,7 @@ def __init__( self.initial_parameters = initial_parameters self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn + self.inplace = inplace def __repr__(self) -> str: """Compute a string representation of the strategy.""" @@ -226,12 +228,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - for _, fit_res in results - ] - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + if self.inplace: + # Does in-place weighted average of results + aggregated_ndarrays = aggregate_inplace(results) + else: + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + aggregated_ndarrays = aggregate(weights_results) + + parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays) # Aggregate custom metrics if aggregation fn was provided metrics_aggregated = {} diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index 377397678e38..6678b7ced114 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ 
b/src/py/flwr/server/strategy/fedavg_android.py @@ -39,10 +39,38 @@ from .strategy import Strategy +# pylint: disable=line-too-long class FedAvgAndroid(Strategy): - """Configurable FedAvg strategy implementation.""" - - # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long + """Federated Averaging strategy. + + Implementation based on https://arxiv.org/abs/1602.05629 + + Parameters + ---------- + fraction_fit : Optional[float] + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : Optional[float] + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : Optional[int] + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : Optional[int] + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : Optional[int] + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Optional[Callable[[int], Dict[str, Scalar]]] + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Optional[Callable[[int], Dict[str, Scalar]]] + Function used to configure validation. Defaults to None. + accept_failures : Optional[bool] + Whether or not accept rounds + containing failures. Defaults to True. + initial_parameters : Optional[Parameters] + Initial global model parameters. + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( self, *, @@ -62,34 +90,6 @@ def __init__( accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, ) -> None: - """Federated Averaging strategy. - - Implementation based on https://arxiv.org/abs/1602.05629 - - Parameters - ---------- - fraction_fit : Optional[float] - Fraction of clients used during training. Defaults to 0.1. 
- fraction_evaluate : Optional[float] - Fraction of clients used during validation. Defaults to 0.1. - min_fit_clients : Optional[int] - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : Optional[int] - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : Optional[int] - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Optional[Callable[[int], Dict[str, Scalar]]] - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Optional[Callable[[int], Dict[str, Scalar]]] - Function used to configure validation. Defaults to None. - accept_failures : Optional[bool] - Whether or not accept rounds - containing failures. Defaults to True. - initial_parameters : Optional[Parameters] - Initial global model parameters. 
- """ super().__init__() self.min_fit_clients = min_fit_clients self.min_evaluate_clients = min_evaluate_clients @@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays: """Convert parameters object to NumPy weights.""" return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors] - # pylint: disable=R0201 def ndarray_to_bytes(self, ndarray: NDArray) -> bytes: """Serialize NumPy array to bytes.""" return ndarray.tobytes() - # pylint: disable=R0201 def bytes_to_ndarray(self, tensor: bytes) -> NDArray: """Deserialize NumPy array from bytes.""" ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32) diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index 947736f4a571..e62eaa5c5832 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,6 +15,16 @@ """FedAvg tests.""" +from typing import List, Tuple, Union +from unittest.mock import MagicMock + +import numpy as np +from numpy.testing import assert_allclose + +from flwr.common import Code, FitRes, Status, parameters_to_ndarrays +from flwr.common.parameter import ndarrays_to_parameters +from flwr.server.client_proxy import ClientProxy + from .fedavg import FedAvg @@ -120,3 +130,51 @@ def test_fedavg_num_evaluation_clients_minimum() -> None: # Assert assert expected == actual + + +def test_inplace_aggregate_fit_equivalence() -> None: + """Test aggregate_fit equivalence between FedAvg and its inplace version.""" + # Prepare + weights0_0 = np.random.randn(100, 64) + weights0_1 = np.random.randn(314, 628, 3) + weights1_0 = np.random.randn(100, 64) + weights1_1 = np.random.randn(314, 628, 3) + + results: List[Tuple[ClientProxy, FitRes]] = [ + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights0_0, weights0_1]), + num_examples=1, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + 
status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights1_0, weights1_1]), + num_examples=5, + metrics={}, + ), + ), + ] + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + + fedavg_reference = FedAvg(inplace=False) + fedavg_inplace = FedAvg() + + # Execute + reference, _ = fedavg_reference.aggregate_fit(1, results, failures) + assert reference + inplace, _ = fedavg_inplace.aggregate_fit(1, results, failures) + assert inplace + + # Convert to NumPy to check similarity + reference_np = parameters_to_ndarrays(reference) + inplace_np = parameters_to_ndarrays(inplace) + + # Assert + for ref, inp in zip(reference_np, inplace_np): + assert_allclose(ref, inp) diff --git a/src/py/flwr/server/strategy/fedavgm.py b/src/py/flwr/server/strategy/fedavgm.py index 37cccd01479c..fb9261abe89d 100644 --- a/src/py/flwr/server/strategy/fedavgm.py +++ b/src/py/flwr/server/strategy/fedavgm.py @@ -38,8 +38,40 @@ from .fedavg import FedAvg +# pylint: disable=line-too-long class FedAvgM(FedAvg): - """Configurable FedAvg with Momentum strategy implementation.""" + """Federated Averaging with Momentum strategy. + + Implementation based on https://arxiv.org/abs/1909.06335 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. 
+ on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + server_learning_rate: float + Server-side learning rate used in server-side optimization. + Defaults to 1.0. + server_momentum: float + Server-side momentum factor used for FedAvgM. Defaults to 0.0. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long def __init__( @@ -65,38 +97,6 @@ def __init__( server_learning_rate: float = 1.0, server_momentum: float = 0.0, ) -> None: - """Federated Averaging with Momentum strategy. - - Implementation based on https://arxiv.org/pdf/1909.06335.pdf - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 0.1. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 0.1. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. 
- accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - server_learning_rate: float - Server-side learning rate used in server-side optimization. - Defaults to 1.0. - server_momentum: float - Server-side momentum factor used for FedAvgM. Defaults to 0.0. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index 7a5bf1425b44..17e979d92beb 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -36,7 +36,7 @@ class FedMedian(FedAvg): - """Configurable FedAvg with Momentum strategy implementation.""" + """Configurable FedMedian strategy implementation.""" def __repr__(self) -> str: """Compute a string representation of the strategy.""" diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index 78dd92061c07..be5f260d96fa 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -31,8 +31,49 @@ from .fedavg import FedAvg +# pylint: disable=line-too-long class FedOpt(FedAvg): - """Configurable FedAdagrad strategy implementation.""" + """Federated Optim strategy. + + Implementation based on https://arxiv.org/abs/2003.00295v5 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. 
+ evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + eta : float, optional + Server-side learning rate. Defaults to 1e-1. + eta_l : float, optional + Client-side learning rate. Defaults to 1e-1. + beta_1 : float, optional + Momentum parameter. Defaults to 0.0. + beta_2 : float, optional + Second moment parameter. Defaults to 0.0. + tau : float, optional + Controls the algorithm's degree of adaptability. Defaults to 1e-9. + """ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals, line-too-long def __init__( @@ -61,47 +102,6 @@ def __init__( beta_2: float = 0.0, tau: float = 1e-9, ) -> None: - """Federated Optim strategy interface. - - Implementation based on https://arxiv.org/abs/2003.00295v5 - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. 
- min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters - Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - eta : float, optional - Server-side learning rate. Defaults to 1e-1. - eta_l : float, optional - Client-side learning rate. Defaults to 1e-1. - beta_1 : float, optional - Momentum parameter. Defaults to 0.0. - beta_2 : float, optional - Second moment parameter. Defaults to 0.0. - tau : float, optional - Controls the algorithm's degree of adaptability. Defaults to 1e-9. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedprox.py b/src/py/flwr/server/strategy/fedprox.py index b2e3db7c31f4..d20f578b193d 100644 --- a/src/py/flwr/server/strategy/fedprox.py +++ b/src/py/flwr/server/strategy/fedprox.py @@ -27,10 +27,82 @@ from .fedavg import FedAvg +# pylint: disable=line-too-long class FedProx(FedAvg): - """Configurable FedProx strategy implementation.""" - - # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long + r"""Federated Optimization strategy. 
+ + Implementation based on https://arxiv.org/abs/1812.06127 + + The strategy in itself will not be different than FedAvg, the client needs to + be adjusted. + A proximal term needs to be added to the loss function during the training: + + .. math:: + \frac{\mu}{2} || w - w^t ||^2 + + Where $w^t$ are the global parameters and $w$ are the local weights the function + will be optimized with. + + In PyTorch, for example, the loss would go from: + + .. code:: python + + loss = criterion(net(inputs), labels) + + To: + + .. code:: python + + for local_weights, global_weights in zip(net.parameters(), global_params): + proximal_term += (local_weights - global_weights).norm(2) + loss = criterion(net(inputs), labels) + (config["proximal_mu"] / 2) * + proximal_term + + With `global_params` being a copy of the parameters before the training takes + place. + + .. code:: python + + global_params = copy.deepcopy(net).parameters() + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. In case `min_fit_clients` + is larger than `fraction_fit * available_clients`, `min_fit_clients` + will still be sampled. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. In case `min_evaluate_clients` + is larger than `fraction_evaluate * available_clients`, + `min_evaluate_clients` will still be sampled. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. 
Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + proximal_mu : float + The weight of the proximal term used in the optimization. 0.0 makes + this strategy equivalent to FedAvg, and the higher the coefficient, the more + regularization will be used (that is, the client parameters will need to be + closer to the server parameters during training). + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( self, *, @@ -53,78 +125,6 @@ def __init__( evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, proximal_mu: float, ) -> None: - r"""Federated Optimization strategy. - - Implementation based on https://arxiv.org/abs/1812.06127 - - The strategy in itself will not be different than FedAvg, the client needs to - be adjusted. - A proximal term needs to be added to the loss function during the training: - - .. math:: - \\frac{\\mu}{2} || w - w^t ||^2 - - Where $w^t$ are the global parameters and $w$ are the local weights the function - will be optimized with. - - In PyTorch, for example, the loss would go from: - - .. code:: python - - loss = criterion(net(inputs), labels) - - To: - - .. code:: python - - for local_weights, global_weights in zip(net.parameters(), global_params): - proximal_term += (local_weights - global_weights).norm(2) - loss = criterion(net(inputs), labels) + (config["proximal_mu"] / 2) * - proximal_term - - With `global_params` being a copy of the parameters before the training takes - place. - - .. 
code:: python - - global_params = copy.deepcopy(net).parameters() - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. In case `min_fit_clients` - is larger than `fraction_fit * available_clients`, `min_fit_clients` - will still be sampled. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. In case `min_evaluate_clients` - is larger than `fraction_evaluate * available_clients`, - `min_evaluate_clients` will still be sampled. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - proximal_mu : float - The weight of the proximal term used in the optimization. 0.0 makes - this strategy equivalent to FedAvg, and the higher the coefficient, the more - regularization will be used (that is, the client parameters will need to be - closer to the server parameters during training). 
- """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedtrimmedavg.py b/src/py/flwr/server/strategy/fedtrimmedavg.py index 8ec3ce89e98e..96b0d35e7a61 100644 --- a/src/py/flwr/server/strategy/fedtrimmedavg.py +++ b/src/py/flwr/server/strategy/fedtrimmedavg.py @@ -35,10 +35,36 @@ from .fedavg import FedAvg +# pylint: disable=line-too-long class FedTrimmedAvg(FedAvg): """Federated Averaging with Trimmed Mean [Dong Yin, et al., 2021]. - Paper: https://arxiv.org/abs/1803.01498 + Implemented based on: https://arxiv.org/abs/1803.01498 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + beta : float, optional + Fraction to cut off of both tails of the distribution. Defaults to 0.2. 
""" # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long @@ -64,33 +90,6 @@ def __init__( evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, beta: float = 0.2, ) -> None: - """Federated Averaging with Trimmed Mean [Dong Yin, et al., 2021]. - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 0.1. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 0.1. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - beta : float, optional - Fraction to cut off of both tails of the distribution. Defaults to 0.2. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py new file mode 100644 index 000000000000..e2707b02d19d --- /dev/null +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -0,0 +1,142 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Federated XGBoost cyclic aggregation strategy.""" + + +from logging import WARNING +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy + +from .fedavg import FedAvg + + +class FedXgbCyclic(FedAvg): + """Configurable FedXgbCyclic strategy implementation.""" + + # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long + def __init__( + self, + **kwargs: Any, + ): + self.global_model: Optional[bytes] = None + super().__init__(**kwargs) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results by keeping the last received client model.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Fetch the client model from last round as global model + for _, fit_res in results: + update = fit_res.parameters.tensors + for bst in update: + self.global_model = bst + + return ( + Parameters(tensor_type="", 
tensors=[cast(bytes, self.global_model)]), + {}, + ) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation metrics using average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.evaluate_metrics_aggregation_fn: + eval_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No evaluate_metrics_aggregation_fn provided") + + return 0, metrics_aggregated + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + config = {} + if self.on_fit_config_fn is not None: + # Custom fit config function provided + config = self.on_fit_config_fn(server_round) + fit_ins = FitIns(parameters, config) + + # Sample clients + sample_size, min_num_clients = self.num_fit_clients( + client_manager.num_available() + ) + clients = client_manager.sample( + num_clients=sample_size, + min_num_clients=min_num_clients, + ) + + # Sample the clients sequentially given server_round + sampled_idx = (server_round - 1) % len(clients) + sampled_clients = [clients[sampled_idx]] + + # Return client/config pairs + return [(client, fit_ins) for client in sampled_clients] + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + # Do not configure federated 
evaluation if fraction eval is 0. + if self.fraction_evaluate == 0.0: + return [] + + # Parameters and config + config = {} + if self.on_evaluate_config_fn is not None: + # Custom evaluation config function provided + config = self.on_evaluate_config_fn(server_round) + evaluate_ins = EvaluateIns(parameters, config) + + # Sample clients + sample_size, min_num_clients = self.num_evaluation_clients( + client_manager.num_available() + ) + clients = client_manager.sample( + num_clients=sample_size, + min_num_clients=min_num_clients, + ) + + # Sample the clients sequentially given server_round + sampled_idx = (server_round - 1) % len(clients) + sampled_clients = [clients[sampled_idx]] + + # Return client/config pairs + return [(client, evaluate_ins) for client in sampled_clients] diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index f300633d0d9f..8dedc925f350 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays -from flwr.common.logger import log +from flwr.common.logger import log, warn_deprecated_feature from flwr.server.client_proxy import ClientProxy from .aggregate import aggregate @@ -33,7 +33,13 @@ class FedXgbNnAvg(FedAvg): - """Configurable FedXgbNnAvg strategy implementation.""" + """Configurable FedXgbNnAvg strategy implementation. + + Warning + ------- + This strategy is deprecated, but a copy of it is available in Flower Baselines: + https://github.com/adap/flower/tree/main/baselines/hfedxgboost. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: """Federated XGBoost [Ma et al., 2023] strategy. @@ -41,6 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: Implementation based on https://arxiv.org/abs/2304.07537. 
""" super().__init__(*args, **kwargs) + warn_deprecated_feature("`FedXgbNnAvg` strategy") def __repr__(self) -> str: """Compute a string representation of the strategy.""" diff --git a/src/py/flwr/server/strategy/fedyogi.py b/src/py/flwr/server/strategy/fedyogi.py index 245090534c7c..7c77aab7ae73 100644 --- a/src/py/flwr/server/strategy/fedyogi.py +++ b/src/py/flwr/server/strategy/fedyogi.py @@ -36,12 +36,50 @@ from .fedopt import FedOpt +# pylint: disable=line-too-long class FedYogi(FedOpt): """FedYogi [Reddi et al., 2020] strategy. - Adaptive Federated Optimization using Yogi. - - Paper: https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[ + Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. 
+ evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + eta : float, optional + Server-side learning rate. Defaults to 1e-2. + eta_l : float, optional + Client-side learning rate. Defaults to 0.0316. + beta_1 : float, optional + Momentum parameter. Defaults to 0.9. + beta_2 : float, optional + Second moment parameter. Defaults to 0.99. + tau : float, optional + Controls the algorithm's degree of adaptability. + Defaults to 1e-3. """ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals, line-too-long @@ -71,48 +109,6 @@ def __init__( beta_2: float = 0.99, tau: float = 1e-3, ) -> None: - """Federated learning strategy using Yogi on server-side. - - Implementation based on https://arxiv.org/abs/2003.00295v5 - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 1.0. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters - Initial global model parameters. 
- fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] - Metrics aggregation function, optional. - eta : float, optional - Server-side learning rate. Defaults to 1e-1. - eta_l : float, optional - Client-side learning rate. Defaults to 1e-1. - beta_1 : float, optional - Momentum parameter. Defaults to 0.9. - beta_2 : float, optional - Second moment parameter. Defaults to 0.99. - tau : float, optional - Controls the algorithm's degree of adaptability. - Defaults to 1e-9. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/krum.py b/src/py/flwr/server/strategy/krum.py index d7f15531902f..16eb5212940e 100644 --- a/src/py/flwr/server/strategy/krum.py +++ b/src/py/flwr/server/strategy/krum.py @@ -39,10 +39,42 @@ from .fedavg import FedAvg +# pylint: disable=line-too-long class Krum(FedAvg): - """Configurable Krum strategy implementation.""" + """Krum [Blanchard et al., 2017] strategy. - # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long + Implementation based on https://arxiv.org/abs/1703.02757 + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + num_malicious_clients : int, optional + Number of malicious clients in the system. Defaults to 0. + num_clients_to_keep : int, optional + Number of clients to keep before averaging (MultiKrum). 
Defaults to 0, in + that case classical Krum is applied. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( self, *, @@ -66,36 +98,6 @@ def __init__( fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, ) -> None: - """Krum strategy. - - Parameters - ---------- - fraction_fit : float, optional - Fraction of clients used during training. Defaults to 0.1. - fraction_evaluate : float, optional - Fraction of clients used during validation. Defaults to 0.1. - min_fit_clients : int, optional - Minimum number of clients used during training. Defaults to 2. - min_evaluate_clients : int, optional - Minimum number of clients used during validation. Defaults to 2. - min_available_clients : int, optional - Minimum number of total clients in the system. Defaults to 2. - num_malicious_clients : int, optional - Number of malicious clients in the system. Defaults to 0. - num_clients_to_keep : int, optional - Number of clients to keep before averaging (MultiKrum). Defaults to 0, in - that case classical Krum is applied. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] - Optional function used for validation. Defaults to None. 
- on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. - accept_failures : bool, optional - Whether or not accept rounds containing failures. Defaults to True. - initial_parameters : Parameters, optional - Initial global model parameters. - """ super().__init__( fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 94a67fbcbfae..758e8e608e9f 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float: hs_ffl = [] if self.pre_weights is None: - raise Exception("QffedAvg pre_weights are None in aggregate_fit") + raise AttributeError("QffedAvg pre_weights are None in aggregate_fit") weights_before = self.pre_weights eval_result = self.evaluate( diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index fd89a01e4a4e..01dbcf982cce 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -17,7 +17,7 @@ from typing import List, Union -from flwr.proto.task_pb2 import TaskIns, TaskRes +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 # pylint: disable-next=too-many-branches,too-many-statements diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index cab51fbf46de..a93e4fb4d457 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -18,9 +18,17 @@ import unittest from typing import List, Tuple -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes -from flwr.proto.transport_pb2 import ClientMessage, ServerMessage 
+from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import ( # pylint: disable=E0611 + SecureAggregation, + Task, + TaskIns, + TaskRes, +) +from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 + ClientMessage, + ServerMessage, +) from .validator import validate_task_ins_or_res @@ -135,7 +143,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -162,7 +170,7 @@ def create_task_res( task_res = TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 0bb9290b6911..6a18a258ac60 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -107,8 +107,8 @@ def start_simulation( List `client_id`s for each client. This is only required if `num_clients` is not set. Setting both `num_clients` and `clients_ids` with `len(clients_ids)` not equal to `num_clients` generates an error. - client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, - "num_gpus": 0.0}` CPU and GPU resources for a single client. Supported keys + client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`) + CPU and GPU resources for a single client. Supported keys are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by `num_gpus`, as well as using custom resources, please consult the Ray documentation. @@ -160,7 +160,7 @@ def start_simulation( ------- hist : flwr.server.history.History Object containing metrics from training. 
- """ + """ # noqa: E501 # pylint: disable-msg=too-many-locals event( EventType.START_SIMULATION_ENTER, @@ -314,18 +314,30 @@ def update_resources(f_stop: threading.Event) -> None: log(ERROR, traceback.format_exc()) log( ERROR, - "Your simulation crashed :(. This could be because of several reasons." + "Your simulation crashed :(. This could be because of several reasons. " "The most common are: " + "\n\t > Sometimes, issues in the simulation code itself can cause crashes. " + "It's always a good idea to double-check your code for any potential bugs " + "or inconsistencies that might be contributing to the problem. " + "For example: " + "\n\t\t - You might be using a class attribute in your clients that " + "hasn't been defined." + "\n\t\t - There could be an incorrect method call to a 3rd party library " + "(e.g., PyTorch)." + "\n\t\t - The return types of methods in your clients/strategies might be " + "incorrect." "\n\t > Your system couldn't fit a single VirtualClient: try lowering " "`client_resources`." "\n\t > All the actors in your pool crashed. This could be because: " "\n\t\t - You clients hit an out-of-memory (OOM) error and actors couldn't " "recover from it. Try launching your simulation with more generous " "`client_resources` setting (i.e. it seems %s is " - "not enough for your workload). Use fewer concurrent actors. " + "not enough for your run). Use fewer concurrent actors. " "\n\t\t - You were running a multi-node simulation and all worker nodes " "disconnected. The head node might still be alive but cannot accommodate " - "any actor with resources: %s.", + "any actor with resources: %s." 
+ "\nTake a look at the Flower simulation examples for guidance " + ".", client_resources, client_resources, ) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 57ea0ed7b187..38af3f08daa2 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -27,7 +27,7 @@ from flwr import common from flwr.client import Client, ClientFn -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common.logger import log from flwr.simulation.ray_transport.utils import check_clientfn_returns_client @@ -61,9 +61,9 @@ def run( client_fn: ClientFn, job_fn: JobFn, cid: str, - state: WorkloadState, - ) -> Tuple[str, ClientRes, WorkloadState]: - """Run a client workload.""" + state: RunState, + ) -> Tuple[str, ClientRes, RunState]: + """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy @@ -76,16 +76,15 @@ def run( job_results = job_fn(client) # Retrieve state (potentially updated) updated_state = client.get_state() - print(f"Actor finishing ({cid}) !!!: {updated_state = }") except Exception as ex: client_trace = traceback.format_exc() message = ( - "\n\tSomething went wrong when running your client workload." + "\n\tSomething went wrong when running your client run." "\n\tClient " + cid + " crashed when the " + self.__class__.__name__ - + " was running its workload." + + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) raise ClientException(str(message)) from ex @@ -95,7 +94,7 @@ def run( @ray.remote class DefaultActor(VirtualClientEngineActor): - """A Ray Actor class that runs client workloads. + """A Ray Actor class that runs client runs. 
Parameters ---------- @@ -238,10 +237,8 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit( - self, fn: Any, value: Tuple[ClientFn, JobFn, str, WorkloadState] - ) -> None: - """Take idle actor and assign it a client workload. + def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: + """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool @@ -258,7 +255,7 @@ def submit( self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, WorkloadState] + self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RunState] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -298,7 +295,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: + def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: """Fetch result and updated state for a VirtualClient from Object Store. 
The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -308,7 +305,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore res_cid, res, updated_state = ray.get( future - ) # type: (str, ClientRes, WorkloadState) + ) # type: (str, ClientRes, RunState) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -412,7 +409,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, WorkloadState]: + ) -> Tuple[ClientRes, RunState]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready @@ -424,5 +421,5 @@ def get_client_result( break # Fetch result belonging to the VirtualClient calling this method - # Return both result from tasks and (potentially) updated workload state + # Return both result from tasks and (potentially) updated run state return self._fetch_future_result(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 9596acaf1d91..5c05850dfd2f 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -29,7 +29,7 @@ maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.workload_state import WorkloadState +from flwr.client.node_state import NodeState from flwr.common.logger import log from flwr.server.client_proxy import ClientProxy from flwr.simulation.ray_transport.ray_actor import ( @@ -129,19 +129,34 @@ def __init__( super().__init__(cid) self.client_fn = client_fn self.actor_pool = actor_pool + self.proxy_state = NodeState() def _submit_job(self, job_fn: JobFn, timeout: 
Optional[float]) -> ClientRes: + # The VCE is not exposed to TaskIns, it won't handle multilple runs + # For the time being, fixing run_id is a small compromise + # This will be one of the first points to address integrating VCE + DriverAPI + run_id = 0 + + # Register state + self.proxy_state.register_runstate(run_id=run_id) + + # Retrieve state + state = self.proxy_state.retrieve_runstate(run_id=run_id) + try: self.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, WorkloadState(state={})), + (self.client_fn, job_fn, self.cid, state), ) - res, _ = self.actor_pool.get_client_result(self.cid, timeout) + res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) + + # Update state + self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) except Exception as ex: if self.actor_pool.num_actors == 0: # At this point we want to stop the simulation. - # since no more client workloads will be executed + # since no more client runs will be executed log(ERROR, "ActorPool is empty!!!") log(ERROR, traceback.format_exc()) log(ERROR, ex) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 44cb4ec70471..9df71635b949 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,7 +22,7 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import Code, GetPropertiesRes, Status from flwr.simulation.ray_transport.ray_actor import ( ClientRes, @@ -46,13 +46,16 @@ def get_dummy_client(cid: str) -> Client: return DummyClient(cid).to_client() -# A dummy workload +# A dummy run def job_fn(cid: str) -> JobFn: # pragma: no cover """Construct a simple job with cid dependency.""" 
def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument result = int(cid) * pi + # store something in state + client.numpy_client.state.state["result"] = str(result) # type: ignore + # now let's convert it to a GetPropertiesRes response return GetPropertiesRes( status=Status(Code(0), message="test"), properties={"result": result} @@ -109,28 +112,40 @@ def test_cid_consistency_one_at_a_time() -> None: ray.shutdown() -def test_cid_consistency_all_submit_first() -> None: +def test_cid_consistency_all_submit_first_run_consistency() -> None: """Test that ClientProxies get the result of client job they submit. - All jobs are submitted at the same time. Then fetched one at a time. + All jobs are submitted at the same time. Then fetched one at a time. This also tests + NodeState (at each Proxy) and RunState basic functionality. """ proxies, _ = prep() + run_id = 0 # submit all jobs (collect later) shuffle(proxies) for prox in proxies: + # Register state + prox.proxy_state.register_runstate(run_id=run_id) + # Retrieve state + state = prox.proxy_state.retrieve_runstate(run_id=run_id) + job = job_fn(prox.cid) prox.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, WorkloadState(state={})), + (prox.client_fn, job, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, _ = prox.actor_pool.get_client_result(prox.cid, timeout=None) + res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) + prox.proxy_state.update_runstate(run_id, run_state=updated_state) res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] + assert ( + str(int(prox.cid) * pi) + == prox.proxy_state.retrieve_runstate(run_id).state["result"] + ) ray.shutdown() @@ -147,7 +162,7 @@ def test_cid_consistency_without_proxies() -> None: job = job_fn(cid) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: 
a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, WorkloadState(state={})), + (get_dummy_client, job, cid, RunState(state={})), ) # fetch results one at a time diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index c8e6aa6cbe21..41aa8049eaf0 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -37,7 +37,7 @@ def enable_tf_gpu_growth() -> None: # the same GPU. # Luckily we can disable this behavior by enabling memory growth # on the GPU. In this way, VRAM allocated to the processes grows based - # on the needs for the workload. (this is for instance the default + # on the needs for the run. (this is for instance the default # behavior in PyTorch) # While this behavior is critical for Actors, you'll likely need it # as well in your main process (where the server runs and might evaluate diff --git a/src/py/flwr_experimental/ops/__init__.py b/src/py/flwr_experimental/ops/__init__.py index b56c757e0207..bad31028e68c 100644 --- a/src/py/flwr_experimental/ops/__init__.py +++ b/src/py/flwr_experimental/ops/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # ============================================================================== """Flower ops provides an opinionated way to provision necessary compute -infrastructure for running Flower workloads.""" +infrastructure for running Flower runs.""" diff --git a/src/py/flwr_tool/init_py_check.py b/src/py/flwr_tool/init_py_check.py index 8cdc2e0ab5be..67425139f991 100755 --- a/src/py/flwr_tool/init_py_check.py +++ b/src/py/flwr_tool/init_py_check.py @@ -36,7 +36,7 @@ def check_missing_init_files(absolute_path: str) -> None: if __name__ == "__main__": if len(sys.argv) == 0: - raise Exception( + raise Exception( # pylint: disable=W0719 "Please provide at least one directory path relative to your current working directory." 
) for i, _ in enumerate(sys.argv): diff --git a/src/py/flwr_tool/protoc.py b/src/py/flwr_tool/protoc.py index 5d3ce942c1e0..b0b078c2eae4 100644 --- a/src/py/flwr_tool/protoc.py +++ b/src/py/flwr_tool/protoc.py @@ -51,7 +51,7 @@ def compile_all() -> None: exit_code = protoc.main(command) if exit_code != 0: - raise Exception(f"Error: {command} failed") + raise Exception(f"Error: {command} failed") # pylint: disable=W0719 if __name__ == "__main__": diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 57ca3ff423c2..607d808c8497 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 5 + assert len(PROTO_FILES) == 6 diff --git a/src/py/flwr_tool/update_changelog.py b/src/py/flwr_tool/update_changelog.py new file mode 100644 index 000000000000..bbd5c7f3dc7b --- /dev/null +++ b/src/py/flwr_tool/update_changelog.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""This module is used to update the changelog.""" + + +import re +from sys import argv + +from github import Github + +REPO_NAME = "adap/flower" +CHANGELOG_FILE = "doc/source/ref-changelog.md" +CHANGELOG_SECTION_HEADER = "### Changelog entry" + + +def _get_latest_tag(gh_api): + """Retrieve the latest tag from the GitHub repository.""" + repo = gh_api.get_repo(REPO_NAME) + tags = repo.get_tags() + return tags[0] if tags.totalCount > 0 else None + + +def _get_pull_requests_since_tag(gh_api, tag): + """Get a list of pull requests merged into the main branch since a given tag.""" + repo = gh_api.get_repo(REPO_NAME) + commits = {commit.sha for commit in repo.compare(tag.commit.sha, "main").commits} + prs = set() + for pr_info in repo.get_pulls( + state="closed", sort="created", direction="desc", base="main" + ): + if pr_info.merge_commit_sha in commits: + prs.add(pr_info) + if len(prs) == len(commits): + break + return prs + + +def _format_pr_reference(title, number, url): + """Format a pull request reference as a markdown list item.""" + return f"- **{title}** ([#{number}]({url}))" + + +def _extract_changelog_entry(pr_info): + """Extract the changelog entry from a pull request's body.""" + if not pr_info.body: + return None, "general" + + entry_match = re.search( + f"{CHANGELOG_SECTION_HEADER}(.+?)(?=##|$)", pr_info.body, re.DOTALL + ) + if not entry_match: + return None, "general" + + entry_text = entry_match.group(1).strip() + + # Remove markdown comments + entry_text = re.sub(r"", "", entry_text, flags=re.DOTALL).strip() + + token_markers = { + "general": "", + "skip": "", + "baselines": "", + "examples": "", + "sdk": "", + "simulations": "", + } + + # Find the token based on the presence of its marker in entry_text + token = next( + (token for token, marker in token_markers.items() if marker in entry_text), None + ) + + return entry_text, token + + +def _update_changelog(prs): + 
"""Update the changelog file with entries from provided pull requests.""" + with open(CHANGELOG_FILE, "r+", encoding="utf-8") as file: + content = file.read() + unreleased_index = content.find("## Unreleased") + + if unreleased_index == -1: + print("Unreleased header not found in the changelog.") + return + + # Find the end of the Unreleased section + next_header_index = content.find("##", unreleased_index + 1) + next_header_index = ( + next_header_index if next_header_index != -1 else len(content) + ) + + for pr_info in prs: + pr_entry_text, category = _extract_changelog_entry(pr_info) + + # Skip if PR should be skipped or already in changelog + if category == "skip" or f"#{pr_info.number}]" in content: + continue + + pr_reference = _format_pr_reference( + pr_info.title, pr_info.number, pr_info.html_url + ) + + # Process based on category + if category in ["general", "baselines", "examples", "sdk", "simulations"]: + entry_title = _get_category_title(category) + content = _update_entry( + content, + entry_title, + pr_info, + unreleased_index, + next_header_index, + ) + + elif pr_entry_text: + content = _insert_new_entry( + content, pr_info, pr_reference, pr_entry_text, unreleased_index + ) + + else: + content = _insert_entry_no_desc(content, pr_reference, unreleased_index) + + next_header_index = content.find("##", unreleased_index + 1) + next_header_index = ( + next_header_index if next_header_index != -1 else len(content) + ) + + # Finalize content update + file.seek(0) + file.write(content) + file.truncate() + + print("Changelog updated.") + + +def _get_category_title(category): + """Get the title of a changelog section based on its category.""" + headers = { + "general": "General improvements", + "baselines": "General updates to Flower Baselines", + "examples": "General updates to Flower Examples", + "sdk": "General updates to Flower SDKs", + "simulations": "General updates to Flower Simulations", + } + return headers.get(category, "") + + +def _update_entry( + 
content, category_title, pr_info, unreleased_index, next_header_index +): + """Update a specific section in the changelog content.""" + if ( + section_index := content.find( + category_title, unreleased_index, next_header_index + ) + ) != -1: + newline_index = content.find("\n", section_index) + closing_parenthesis_index = content.rfind(")", unreleased_index, newline_index) + updated_entry = f", [{pr_info.number}]({pr_info.html_url})" + content = ( + content[:closing_parenthesis_index] + + updated_entry + + content[closing_parenthesis_index:] + ) + else: + new_section = ( + f"\n- **{category_title}** ([#{pr_info.number}]({pr_info.html_url}))\n" + ) + insert_index = content.find("\n", unreleased_index) + 1 + content = content[:insert_index] + new_section + content[insert_index:] + return content + + +def _insert_new_entry(content, pr_info, pr_reference, pr_entry_text, unreleased_index): + """Insert a new entry into the changelog.""" + if (existing_entry_start := content.find(pr_entry_text)) != -1: + pr_ref_end = content.rfind("\n", 0, existing_entry_start) + updated_entry = ( + f"{content[pr_ref_end]}\n, [{pr_info.number}]({pr_info.html_url})" + ) + content = content[:pr_ref_end] + updated_entry + content[existing_entry_start:] + else: + insert_index = content.find("\n", unreleased_index) + 1 + content = ( + content[:insert_index] + + pr_reference + + "\n " + + pr_entry_text + + "\n" + + content[insert_index:] + ) + return content + + +def _insert_entry_no_desc(content, pr_reference, unreleased_index): + """Insert a changelog entry for a pull request with no specific description.""" + insert_index = content.find("\n", unreleased_index) + 1 + content = ( + content[:insert_index] + "\n" + pr_reference + "\n" + content[insert_index:] + ) + return content + + +def main(): + """Update changelog using the descriptions of PRs since the latest tag.""" + # Initialize GitHub Client with provided token (as argument) + gh_api = Github(argv[1]) + latest_tag = 
_get_latest_tag(gh_api) + if not latest_tag: + print("No tags found in the repository.") + return + + prs = _get_pull_requests_since_tag(gh_api, latest_tag) + _update_changelog(prs) + + +if __name__ == "__main__": + main()