diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst
index be7094614c63..111920d5602b 100644
--- a/doc/source/tutorial-quickstart-xgboost.rst
+++ b/doc/source/tutorial-quickstart-xgboost.rst
@@ -7,6 +7,773 @@ Quickstart XGBoost
 .. meta::
    :description: Check out this Federated Learning quickstart tutorial for using Flower with XGBoost to train classification models on trees.
-Let's build a horizontal federated learning system using XGBoost and Flower!
+Federated XGBoost
+-----------------
-Please refer to the `full code example `_ to learn more.
+EXtreme Gradient Boosting (**XGBoost**) is a robust and efficient implementation of gradient-boosted decision trees (**GBDT**) that pushes the limits of computational resources for boosted tree methods.
+It is primarily designed to improve both the performance and the computational speed of machine learning models.
+In XGBoost, large parts of the tree construction are parallelised, unlike the strictly sequential approach taken by traditional GBDT implementations.
+
+Often, for tabular data on medium-sized datasets with fewer than 10k training examples, XGBoost surpasses the results of deep learning techniques.
+
+Why federated XGBoost?
+~~~~~~~~~~~~~~~~~~~~~~
+
+As the demand for data privacy and decentralised learning grows, there is an increasing need for federated XGBoost systems in specialised applications such as survival analysis and financial fraud detection.
+
+Federated learning ensures that raw data remains on the local device, making it an attractive approach for sensitive domains where data security and privacy are paramount.
+Given the robustness and efficiency of XGBoost, combining it with federated learning offers a promising solution for these specific challenges.
+
+In this tutorial we will learn how to train a federated XGBoost model on the HIGGS dataset using Flower and the :code:`xgboost` package.
+We use a simple example (`full code xgboost-quickstart `_) with two *clients* and one *server*
+to demonstrate how federated XGBoost works,
+and then we dive into a more complex example (`full code xgboost-comprehensive `_) to run various experiments.
+
+
+Environment Setup
+-----------------
+
+First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_.
+
+We first need to install Flower and Flower Datasets. You can do this by running:
+
+.. code-block:: shell
+
+    $ pip install flwr flwr-datasets
+
+Since we want to use the :code:`xgboost` package to build XGBoost trees, let's go ahead and install :code:`xgboost`:
+
+.. code-block:: shell
+
+    $ pip install xgboost
+
+
+Flower Client
+-------------
+
+*Clients* are responsible for generating individual weight updates for the model based on their local datasets.
+Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server.
+
+In a file called :code:`client.py`, import xgboost, Flower, Flower Datasets and other related functions:
+
+.. code-block:: python
+
+    import argparse
+    from typing import Union
+    from logging import INFO
+    from datasets import Dataset, DatasetDict
+    import xgboost as xgb
+
+    import flwr as fl
+    from flwr_datasets import FederatedDataset
+    from flwr.common.logger import log
+    from flwr.common import (
+        Code,
+        EvaluateIns,
+        EvaluateRes,
+        FitIns,
+        FitRes,
+        GetParametersIns,
+        GetParametersRes,
+        Parameters,
+        Status,
+    )
+    from flwr_datasets.partitioner import IidPartitioner
+
+Dataset partition and hyper-parameter selection
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Prior to local training, we need to load the HIGGS dataset from Flower Datasets and partition it for FL:
+
+.. code-block:: python
+
+    # Load (HIGGS) dataset and conduct partitioning
+    # We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process.
+    partitioner = IidPartitioner(num_partitions=30)
+    fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})
+
+    # Load the partition for this `node_id`
+    partition = fds.load_partition(node_id=args.node_id, split="train")
+    partition.set_format("numpy")
+
+In this example, we split the dataset into 30 partitions with uniform distribution (:code:`IidPartitioner(num_partitions=30)`).
+Then, we load the partition for the given client based on :code:`node_id`:
+
+.. code-block:: python
+
+    # We first define the arguments parser for the user to specify the client/node ID.
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--node-id",
+        default=0,
+        type=int,
+        help="Node ID used for the current client.",
+    )
+    args = parser.parse_args()
+
+    # Load the partition for this `node_id`.
+    partition = fds.load_partition(node_id=args.node_id, split="train")
+    partition.set_format("numpy")
+
+After that, we do train/test splitting on the given partition (the client's local data) and transform the data into the DMatrix format required by the :code:`xgboost` package.
+
+.. code-block:: python
+
+    # Train/test splitting
+    train_data, valid_data, num_train, num_val = train_test_split(
+        partition, test_fraction=0.2, seed=42
+    )
+
+    # Reformat data to DMatrix for xgboost
+    train_dmatrix = transform_dataset_to_dmatrix(train_data)
+    valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
+
+The functions :code:`train_test_split` and :code:`transform_dataset_to_dmatrix` are defined below:
+
+.. code-block:: python
+
+    # Define data partitioning related functions
+    def train_test_split(partition: Dataset, test_fraction: float, seed: int):
+        """Split the data into train and validation set given split rate."""
+        train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
+        partition_train = train_test["train"]
+        partition_test = train_test["test"]
+
+        num_train = len(partition_train)
+        num_test = len(partition_test)
+
+        return partition_train, partition_test, num_train, num_test
+
+
+    def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
+        """Transform dataset to DMatrix format for xgboost."""
+        x = data["inputs"]
+        y = data["label"]
+        new_data = xgb.DMatrix(x, label=y)
+        return new_data
+
+Finally, we define the hyper-parameters used for XGBoost training.
+
+.. code-block:: python
+
+    num_local_round = 1
+    params = {
+        "objective": "binary:logistic",
+        "eta": 0.1,  # Learning rate
+        "max_depth": 8,
+        "eval_metric": "auc",
+        "nthread": 16,
+        "num_parallel_tree": 1,
+        "subsample": 1,
+        "tree_method": "hist",
+    }
+
+The :code:`num_local_round` represents the number of iterations for local tree boosting.
+We use AUC as the evaluation metric.
+We train on CPU by default; one can switch to GPU by setting :code:`tree_method` to :code:`gpu_hist`.
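+
+For instance, a minimal sketch of switching local training to the GPU, assuming an XGBoost build with CUDA support is installed, could look like this:
+
+.. code-block:: python
+
+    # Hypothetical tweak: train on the GPU instead of the CPU
+    # (requires an XGBoost build with CUDA support)
+    params.update({"tree_method": "gpu_hist"})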
+
+
+Flower client definition for XGBoost
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+After loading the dataset, we define the Flower client.
+We follow the general pattern of defining an :code:`XgbClient` class that inherits from :code:`fl.client.Client`.
+
+.. code-block:: python
+
+    class XgbClient(fl.client.Client):
+        def __init__(self):
+            self.bst = None
+            self.config = None
+
+The :code:`self.bst` attribute keeps the Booster object consistent across rounds,
+allowing it to store the predictions from trees integrated in earlier rounds and to maintain other essential data structures for training.
+
+Then, we override the :code:`get_parameters`, :code:`fit` and :code:`evaluate` methods inside the :code:`XgbClient` class as follows.
+
+.. code-block:: python
+
+    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
+        _ = (self, ins)
+        return GetParametersRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[]),
+        )
+
+Unlike neural network training, XGBoost trees are not initialised from specified random weights.
+In this case, we do not use :code:`get_parameters` and :code:`set_parameters` to initialise model parameters for XGBoost.
+As a result, let's return an empty tensor in :code:`get_parameters` when it is called by the server at the first round.
+
+.. code-block:: python
+
+    def fit(self, ins: FitIns) -> FitRes:
+        if not self.bst:
+            # First round local training
+            log(INFO, "Start training at round 1")
+            bst = xgb.train(
+                params,
+                train_dmatrix,
+                num_boost_round=num_local_round,
+                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
+            )
+            self.config = bst.save_config()
+            self.bst = bst
+        else:
+            for item in ins.parameters.tensors:
+                global_model = bytearray(item)
+
+            # Load global model into booster
+            self.bst.load_model(global_model)
+            self.bst.load_config(self.config)
+
+            bst = self._local_boost()
+
+        local_model = bst.save_raw("json")
+        local_model_bytes = bytes(local_model)
+
+        return FitRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+            num_examples=num_train,
+            metrics={},
+        )
+
+In :code:`fit`, at the first round, we call :code:`xgb.train()` to build up the first set of trees.
+The returned Booster object and config are stored in :code:`self.bst` and :code:`self.config`, respectively.
+From the second round onwards, we load the global model sent from the server into :code:`self.bst`,
+and then update the model weights on the local training data with the :code:`_local_boost` method as follows:
+
+.. code-block:: python
+
+    def _local_boost(self):
+        # Update trees based on local training data.
+        for i in range(num_local_round):
+            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
+
+        # Extract the last N=num_local_round trees for server aggregation
+        bst = self.bst[
+            self.bst.num_boosted_rounds()
+            - num_local_round : self.bst.num_boosted_rounds()
+        ]
+
+        return bst
+
+Given :code:`num_local_round`, we update the trees by calling the :code:`self.bst.update` method.
+After training, the last :code:`N=num_local_round` trees are extracted and sent to the server.
+
+.. code-block:: python
+
+    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+        eval_results = self.bst.eval_set(
+            evals=[(valid_dmatrix, "valid")],
+            iteration=self.bst.num_boosted_rounds() - 1,
+        )
+        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+        return EvaluateRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            loss=0.0,
+            num_examples=num_val,
+            metrics={"AUC": auc},
+        )
+
+In :code:`evaluate`, we call the :code:`self.bst.eval_set` function to conduct evaluation on the validation set.
+The AUC value is returned.
+
+Now, we can create an instance of our class :code:`XgbClient` and add one line to actually run this client:
+
+.. code-block:: python
+
+    fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient())
+
+That's it for the client. We only have to implement :code:`Client` and call :code:`fl.client.start_client()`.
+The string :code:`"127.0.0.1:8080"` tells the client which server to connect to.
+In our case we can run the server and the client on the same machine, therefore we use
+:code:`"127.0.0.1:8080"`. If we run a truly federated workload with the server and
+clients running on different machines, all that needs to change is the
+:code:`server_address` we point the client at.
+
+
+Flower Server
+-------------
+
+These updates are then sent to the *server* which will aggregate them to produce a better model.
+Finally, the *server* sends this improved version of the model back to each *client* to finish a complete FL round.
+
+In a file named :code:`server.py`, import Flower and :code:`FedXgbBagging` from :code:`flwr.server.strategy`.
+
+We first define a strategy for XGBoost bagging aggregation.
+
+.. code-block:: python
+
+    # Define strategy
+    strategy = FedXgbBagging(
+        fraction_fit=1.0,
+        min_fit_clients=2,
+        min_available_clients=2,
+        min_evaluate_clients=2,
+        fraction_evaluate=1.0,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+    )
+
+    def evaluate_metrics_aggregation(eval_metrics):
+        """Return an aggregated metric (AUC) for evaluation."""
+        total_num = sum([num for num, _ in eval_metrics])
+        auc_aggregated = (
+            sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
+        )
+        metrics_aggregated = {"AUC": auc_aggregated}
+        return metrics_aggregated
+
+We use two clients for this example.
+An :code:`evaluate_metrics_aggregation` function is defined to collect and compute a weighted average of the AUC values from the clients.
+
+Then, we start the server:
+
+.. code-block:: python
+
+    # Start Flower server
+    fl.server.start_server(
+        server_address="0.0.0.0:8080",
+        config=fl.server.ServerConfig(num_rounds=num_rounds),
+        strategy=strategy,
+    )
+
+Tree-based bagging aggregation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You must be curious about how bagging aggregation works. Let's look into the details.
+
+In :code:`flwr.server.strategy.fedxgb_bagging.py`, we define :code:`FedXgbBagging`, which inherits from :code:`flwr.server.strategy.FedAvg`.
+Then, we override the :code:`aggregate_fit`, :code:`aggregate_evaluate` and :code:`evaluate` methods as follows:
+
+..
code-block:: python + + import json + from logging import WARNING + from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + + from flwr.common import EvaluateRes, FitRes, Parameters, Scalar + from flwr.common.logger import log + from flwr.server.client_proxy import ClientProxy + + from .fedavg import FedAvg + + + class FedXgbBagging(FedAvg): + """Configurable FedXgbBagging strategy implementation.""" + + def __init__( + self, + evaluate_function: Optional[ + Callable[ + [int, Parameters, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + **kwargs: Any, + ): + self.evaluate_function = evaluate_function + self.global_model: Optional[bytes] = None + super().__init__(**kwargs) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using bagging.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate all the client trees + global_model = self.global_model + for _, fit_res in results: + update = fit_res.parameters.tensors + for bst in update: + global_model = aggregate(global_model, bst) + + self.global_model = global_model + + return ( + Parameters(tensor_type="", tensors=[cast(bytes, global_model)]), + {}, + ) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation metrics using average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.evaluate_metrics_aggregation_fn: + eval_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No evaluate_metrics_aggregation_fn provided") + + return 0, metrics_aggregated + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function.""" + if self.evaluate_function is None: + # No evaluation function provided + return None + eval_res = self.evaluate_function(server_round, parameters, {}) + if eval_res is None: + return None + loss, metrics = eval_res + return loss, metrics + +In :code:`aggregate_fit`, we sequentially aggregate the clients' XGBoost trees by calling :code:`aggregate()` function: + +.. 
code-block:: python + + def aggregate( + bst_prev_org: Optional[bytes], + bst_curr_org: bytes, + ) -> bytes: + """Conduct bagging aggregation for given trees.""" + if not bst_prev_org: + return bst_curr_org + + # Get the tree numbers + tree_num_prev, _ = _get_tree_nums(bst_prev_org) + _, paral_tree_num_curr = _get_tree_nums(bst_curr_org) + + bst_prev = json.loads(bytearray(bst_prev_org)) + bst_curr = json.loads(bytearray(bst_curr_org)) + + bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ + "num_trees" + ] = str(tree_num_prev + paral_tree_num_curr) + iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][ + "iteration_indptr" + ] + bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( + iteration_indptr[-1] + paral_tree_num_curr + ) + + # Aggregate new trees + trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"] + for tree_count in range(paral_tree_num_curr): + trees_curr[tree_count]["id"] = tree_num_prev + tree_count + bst_prev["learner"]["gradient_booster"]["model"]["trees"].append( + trees_curr[tree_count] + ) + bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0) + + bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8") + + return bst_prev_bytes + + + def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]: + xgb_model = json.loads(bytearray(xgb_model_org)) + # Get the number of trees + tree_num = int( + xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ + "num_trees" + ] + ) + # Get the number of parallel trees + paral_tree_num = int( + xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ + "num_parallel_tree" + ] + ) + return tree_num, paral_tree_num + +In this function, we first fetch the number of trees and the number of parallel trees for the current and previous model +by calling :code:`_get_tree_nums`. +Then, the fetched information will be aggregated. +After that, the trees (containing model weights) are aggregated to generate a new tree model. + +After traversal of all clients' models, a new global model is generated, +followed by the serialisation, and sending back to each client. + + +Launch Federated XGBoost! +--------------------------- + +With both client and server ready, we can now run everything and see federated +learning in action. FL systems usually have a server and multiple clients. We +therefore have to start the server first: + +.. code-block:: shell + + $ python3 server.py + +Once the server is running we can start the clients in different terminals. +Open a new terminal and start the first client: + +.. code-block:: shell + + $ python3 client.py --node-id=0 + +Open another terminal and start the second client: + +.. code-block:: shell + + $ python3 client.py --node-id=1 + +Each client will have its own dataset. +You should now see how the training does in the very first terminal (the one that started the server): + +.. 
code-block:: shell + + INFO flwr 2023-11-20 11:21:56,454 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=5, round_timeout=None) + INFO flwr 2023-11-20 11:21:56,473 | app.py:176 | Flower ECE: gRPC server running (5 rounds), SSL is disabled + INFO flwr 2023-11-20 11:21:56,473 | server.py:89 | Initializing global parameters + INFO flwr 2023-11-20 11:21:56,473 | server.py:276 | Requesting initial parameters from one random client + INFO flwr 2023-11-20 11:22:38,302 | server.py:280 | Received initial parameters from one random client + INFO flwr 2023-11-20 11:22:38,302 | server.py:91 | Evaluating initial parameters + INFO flwr 2023-11-20 11:22:38,302 | server.py:104 | FL starting + DEBUG flwr 2023-11-20 11:22:38,302 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,636 | server.py:236 | fit_round 1 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,643 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,653 | server.py:187 | evaluate_round 1 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,653 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,721 | server.py:236 | fit_round 2 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,745 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,756 | server.py:187 | evaluate_round 2 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,756 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,831 | server.py:236 | fit_round 3 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,868 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,881 | server.py:187 | evaluate_round 3 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:38,881 | server.py:222 | fit_round 4: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:38,960 | server.py:236 | fit_round 4 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:39,012 | server.py:173 | evaluate_round 4: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:39,026 | server.py:187 | evaluate_round 4 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:39,026 | server.py:222 | fit_round 5: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:39,111 | server.py:236 | fit_round 5 received 2 results and 0 failures + DEBUG flwr 2023-11-20 11:22:39,177 | server.py:173 | evaluate_round 5: strategy sampled 2 clients (out of 2) + DEBUG flwr 2023-11-20 11:22:39,193 | server.py:187 | evaluate_round 5 received 2 results and 0 failures + INFO flwr 2023-11-20 11:22:39,193 | server.py:153 | FL finished in 0.8905023969999988 + INFO flwr 2023-11-20 11:22:39,193 | app.py:226 | app_fit: losses_distributed [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)] + INFO flwr 2023-11-20 11:22:39,193 | app.py:227 | app_fit: metrics_distributed_fit {} + INFO flwr 2023-11-20 11:22:39,193 | app.py:228 | app_fit: metrics_distributed {'AUC': [(1, 0.7572), (2, 0.7705), (3, 0.77595), (4, 0.78), (5, 0.78385)]} + INFO flwr 2023-11-20 11:22:39,193 | app.py:229 | app_fit: losses_centralized [] + INFO flwr 2023-11-20 11:22:39,193 | app.py:230 | app_fit: metrics_centralized {} + +Congratulations! +You've successfully built and run your first federated XGBoost system. 
+The AUC values can be checked in :code:`metrics_distributed`.
+One can see that the average AUC increases over FL rounds.
+
+The full `source code `_ for this example can be found in :code:`examples/xgboost-quickstart`.
+
+
+Comprehensive Federated XGBoost
+-------------------------------
+
+Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings.
+In the xgboost-comprehensive example (`full code `_),
+we provide more options to define various experimental setups, including data partitioning and centralised/distributed evaluation.
+Let's take a look!
+
+Customised data partitioning
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In :code:`dataset.py`, we have a function :code:`instantiate_partitioner` to instantiate the data partitioner
+based on the given :code:`num_partitions` and :code:`partitioner_type`.
+Currently, we provide four supported partitioner types to simulate the uniformity/non-uniformity in data quantity (uniform, linear, square, exponential).
+
+.. code-block:: python
+
+    from flwr_datasets.partitioner import (
+        IidPartitioner,
+        LinearPartitioner,
+        SquarePartitioner,
+        ExponentialPartitioner,
+    )
+
+    CORRELATION_TO_PARTITIONER = {
+        "uniform": IidPartitioner,
+        "linear": LinearPartitioner,
+        "square": SquarePartitioner,
+        "exponential": ExponentialPartitioner,
+    }
+
+
+    def instantiate_partitioner(partitioner_type: str, num_partitions: int):
+        """Initialise partitioner based on selected partitioner type and number of
+        partitions."""
+        partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
+            num_partitions=num_partitions
+        )
+        return partitioner
+
+
+Customised centralised/distributed evaluation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To facilitate centralised evaluation, we define a function in :code:`server.py`:
+
+.. code-block:: python
+
+    def get_evaluate_fn(test_data):
+        """Return a function for centralised evaluation."""
+
+        def evaluate_fn(
+            server_round: int, parameters: Parameters, config: Dict[str, Scalar]
+        ):
+            # If at the first round, skip the evaluation
+            if server_round == 0:
+                return 0, {}
+            else:
+                bst = xgb.Booster(params=params)
+                for para in parameters.tensors:
+                    para_b = bytearray(para)
+
+                # Load global model
+                bst.load_model(para_b)
+                # Run evaluation
+                eval_results = bst.eval_set(
+                    evals=[(test_data, "valid")],
+                    iteration=bst.num_boosted_rounds() - 1,
+                )
+                auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+                log(INFO, f"AUC = {auc} at round {server_round}")
+
+                return 0, {"AUC": auc}
+
+        return evaluate_fn
+
+This function returns an evaluation function, which instantiates a :code:`Booster` object and loads the global model weights into it.
+The evaluation is conducted by calling the :code:`eval_set()` method, and the resulting AUC value is reported.
+
+As for distributed evaluation on the clients, it's the same as in the quickstart example:
+we override the :code:`evaluate()` method inside the :code:`XgbClient` class in :code:`client.py`.
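+
+As a rough sketch of how such a function might be wired into the strategy, one could pass it via the :code:`evaluate_function` argument of :code:`FedXgbBagging`; the actual code in :code:`server.py` may differ, e.g. it may only enable this when :code:`--centralised-eval` is set:
+
+.. code-block:: python
+
+    # Hypothetical wiring of centralised evaluation into the bagging strategy;
+    # `test_dmatrix` is assumed to hold the centralised test set as a DMatrix.
+    strategy = FedXgbBagging(
+        evaluate_function=get_evaluate_fn(test_dmatrix),
+        fraction_fit=1.0,
+        fraction_evaluate=1.0,
+    )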
+ ) + parser.add_argument( + "--num-rounds", default=5, type=int, help="Number of FL rounds." + ) + parser.add_argument( + "--num-clients-per-round", + default=2, + type=int, + help="Number of clients participate in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=2, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + + args = parser.parse_args() + return args + +This allows user to specify the number of total clients / FL rounds / participating clients / clients for evaluation, +and evaluation fashion. Note that with :code:`--centralised-eval`, the sever will do centralised evaluation +and all functionalities for client evaluation will be disabled. + +Then, the argument parser on client side: + +.. code-block:: python + + def client_args_parser(): + """Parse arguments to define experimental settings on client side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--num-partitions", default=10, type=int, help="Number of partitions." + ) + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--node-id", + default=0, + type=int, + help="Node ID used for the current client.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." + ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + + args = parser.parse_args() + return args + +This defines various options for client data partitioning. +Besides, clients also have a option to conduct evaluation on centralised test set by setting :code:`--centralised-eval`. + +Example commands +~~~~~~~~ + +To run a centralised evaluated experiment on 5 clients with exponential distribution for 50 rounds, +we first start the server as below: + +.. code-block:: shell + + $ python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval + +Then, on each client terminal, we start the clients: + +.. code-block:: shell + + $ python3 clients.py --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID + +The full `source code `_ for this comprehensive example can be found in :code:`examples/xgboost-comprehensive`. diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md new file mode 100644 index 000000000000..65ef000c26f2 --- /dev/null +++ b/examples/mt-pytorch-callable/README.md @@ -0,0 +1,49 @@ +# Deploy ๐Ÿงช + +๐Ÿงช = this page covers experimental features that might change in future versions of Flower + +This how-to guide describes the deployment of a long-running Flower server. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. 
+โ””โ”€โ”€ client.py +โ”œโ”€โ”€ driver.py +โ”œโ”€โ”€ requirements.txt +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Start the long-running Flower server + +```bash +flower-server --insecure +``` + +## Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +## Start the Driver script + +```bash +python driver.py +``` diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py new file mode 100644 index 000000000000..6f9747784ae0 --- /dev/null +++ b/examples/mt-pytorch-callable/client.py @@ -0,0 +1,123 @@ +import warnings +from collections import OrderedDict + +import flwr as fl +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor +from tqdm import tqdm + + +# ############################################################################# +# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader +# ############################################################################# + +warnings.filterwarnings("ignore", category=UserWarning) +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def train(net, trainloader, epochs): + """Train the model on the training set.""" + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + for _ in range(epochs): + for images, labels in tqdm(trainloader): + optimizer.zero_grad() + criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() + optimizer.step() + + +def test(net, testloader): + """Validate the model on the test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in tqdm(testloader): + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +# ############################################################################# +# 2. 
Federation of the pipeline with Flower +# ############################################################################# + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + self.set_parameters(parameters) + train(net, trainloader, epochs=1) + return self.get_parameters(config={}), len(trainloader.dataset), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """.""" + return FlowerClient().to_client() + + +# To run this: `flower-client --callable client:flower` +flower = fl.flower.Flower( + client_fn=client_fn, +) + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:9092", + client=FlowerClient().to_client(), + transport="grpc-rere", + ) diff --git a/examples/mt-pytorch-callable/driver.py b/examples/mt-pytorch-callable/driver.py new file mode 100644 index 000000000000..1248672b6813 --- /dev/null +++ b/examples/mt-pytorch-callable/driver.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower driver +fl.driver.start_driver( + server_address="0.0.0.0:9091", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/mt-pytorch-callable/pyproject.toml b/examples/mt-pytorch-callable/pyproject.toml new file mode 100644 index 000000000000..0d1a91836006 --- /dev/null +++ b/examples/mt-pytorch-callable/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "quickstart-pytorch" +version = "0.1.0" +description = "PyTorch Federated Learning Quickstart with Flower" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } +torch = "1.13.1" +torchvision = "0.14.1" +tqdm = "4.65.0" diff --git a/examples/mt-pytorch-callable/requirements.txt b/examples/mt-pytorch-callable/requirements.txt new file mode 100644 index 000000000000..797ca6db6244 --- /dev/null +++ b/examples/mt-pytorch-callable/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.65.0 diff --git a/examples/mt-pytorch-callable/run.sh b/examples/mt-pytorch-callable/run.sh new file mode 100755 index 000000000000..d2bf34f834b1 --- /dev/null +++ 
b/examples/mt-pytorch-callable/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +# Download the CIFAR-10 dataset +python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/mt-pytorch-callable/server.py b/examples/mt-pytorch-callable/server.py new file mode 100644 index 000000000000..fe691a88aba0 --- /dev/null +++ b/examples/mt-pytorch-callable/server.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/quickstart-mlcube/dev/requirements.txt b/examples/quickstart-mlcube/dev/requirements.txt index c39a5fa73f81..5cfa618878d2 100644 --- a/examples/quickstart-mlcube/dev/requirements.txt +++ b/examples/quickstart-mlcube/dev/requirements.txt @@ -1,4 +1,4 @@ -PyYAML==5.3 +PyYAML==5.4 tensorflow==2.14.0 tensorflow-estimator==2.14.0 requests[security] diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index 3b31e23cb321..da002a10d301 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -1,7 +1,14 @@ -# Flower Example using XGBoost +# Flower Example using XGBoost (Comprehensive) -This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. -Tree-based with bagging method is used for aggregation on the server. +This example demonstrates a comprehensive federated learning setup using Flower with XGBoost. +We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset to perform a binary classification task. +It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) example in the following ways: + +- Arguments parsers of server and clients for hyperparameters selection. +- Customised FL settings. +- Customised number of partitions. +- Customised partitioner type (uniform, linear, square, exponential). +- Centralised/distributed evaluation. ## Project Setup @@ -18,7 +25,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f -- server.py <- Defines the server-side logic -- client.py <- Defines the client-side logic -- dataset.py <- Defines the functions of data loading and partitioning --- utils.py <- Defines the arguments parser for clients and server. 
+-- utils.py <- Defines the arguments parser for clients and server +-- run.sh <- Commands to run experiments -- pyproject.toml <- Example dependencies (if you use Poetry) -- requirements.txt <- Example dependencies ``` @@ -52,36 +60,35 @@ pip install -r requirements.txt ## Run Federated Learning with XGBoost and Flower -Afterwards you are ready to start the Flower server as well as the clients. -You can simply start the server in a terminal as follows: +The included `run.sh` will start the Flower server (using `server.py`) with centralised evaluation, +sleep for 15 seconds to ensure that the server is up, +and then start 5 Flower clients (using `client.py`) with a small subset of the data from exponential partition distribution. +You can simply start everything in a terminal as follows: ```shell -python3 server.py +poetry run ./run.sh ``` -Now you are ready to start the Flower clients which will participate in the learning. -To do so simply open two more terminal windows and run the following commands. - -Start client 1 in the first terminal: +The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. +If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, +which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. +This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). +If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) +to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -```shell -python3 client.py --node-id=0 -``` - -Start client 2 in the second terminal: - -```shell -python3 client.py --node-id=1 -``` +You can also manually run `poetry run python3 server.py --pool-size=N --num-clients-per-round=N` +and `poetry run python3 client.py --node-id=NODE_ID --num-partitions=N` for as many clients as you want, +but you have to make sure that each command is run in a different terminal window (or a different computer on the network). -You will see that XGBoost is starting a federated training. +In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). +Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) +and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. -Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows: +### Expected Experimental Results -```shell -bash run.sh -``` +![](_static/xgboost_flower_auc.png) -Besides, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). -Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) -and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. +The figure above shows the centralised tested AUC performance over FL rounds on 4 experimental settings. +One can see that all settings obtain stable performance boost over FL rounds (especially noticeable at the start of training). 
+As expected, uniform client distribution shows higher AUC values (beyond 83% at the end) than square/exponential setup. +Feel free to explore more interesting experiments by yourself! diff --git a/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png b/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png new file mode 100644 index 000000000000..e6a4bfb83250 Binary files /dev/null and b/examples/xgboost-comprehensive/_static/xgboost_flower_auc.png differ diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index 5aba30266b5a..a37edac32648 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -23,7 +23,7 @@ transform_dataset_to_dmatrix, resplit, ) -from utils import client_args_parser +from utils import client_args_parser, BST_PARAMS warnings.filterwarnings("ignore", category=UserWarning) @@ -43,12 +43,15 @@ partitioner_type=partitioner_type, num_partitions=num_partitions ) fds = FederatedDataset( - dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit + dataset="jxie/higgs", + partitioners={"train": partitioner}, + resplitter=resplit, ) -# Let's use the first partition as an example +# Load the partition for this `node_id` +log(INFO, "Loading partition...") node_id = args.node_id -partition = fds.load_partition(idx=node_id, split="train") +partition = fds.load_partition(node_id=node_id, split="train") partition.set_format("numpy") if args.centralised_eval: @@ -67,26 +70,18 @@ ) # Reformat data to DMatrix for xgboost +log(INFO, "Reformatting data...") train_dmatrix = transform_dataset_to_dmatrix(train_data) valid_dmatrix = transform_dataset_to_dmatrix(valid_data) # Hyper-parameters for xgboost training num_local_round = 1 -params = { - "objective": "binary:logistic", - "eta": 0.1, # Learning rate - "max_depth": 8, - "eval_metric": "auc", - "nthread": 16, - "num_parallel_tree": 1, - "subsample": 1, - "tree_method": "hist", -} +params = BST_PARAMS # Define Flower client -class FlowerClient(fl.client.Client): +class XgbClient(fl.client.Client): def __init__(self): self.bst = None self.config = None @@ -171,4 +166,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml index 5414b5122154..bbfbb4134b8d 100644 --- a/examples/xgboost-comprehensive/pyproject.toml +++ b/examples/xgboost-comprehensive/pyproject.toml @@ -10,6 +10,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" +flwr-nightly = ">=1.0,<2.0" flwr-datasets = ">=0.0.2,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-comprehensive/requirements.txt b/examples/xgboost-comprehensive/requirements.txt index c6b9c1a67894..c37ac2b6ad6d 100644 --- a/examples/xgboost-comprehensive/requirements.txt +++ b/examples/xgboost-comprehensive/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0, <2.0 +flwr-nightly>=1.0, <2.0 flwr-datasets>=0.0.2, <1.0.0 xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-comprehensive/run.sh b/examples/xgboost-comprehensive/run.sh index 7cf65fa4d52d..7920f6bf5e55 100755 --- a/examples/xgboost-comprehensive/run.sh +++ b/examples/xgboost-comprehensive/run.sh @@ -3,12 +3,12 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 
&& pwd )"/ echo "Starting server" -python server.py & +python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval & sleep 15 # Sleep for 15s to give the server enough time to start -for i in `seq 0 1`; do +for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i & + python3 client.py --node-id=$i --num-partitions=5 --partitioner-type=exponential & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index e4c597ee17eb..3da7e8d9865c 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -8,7 +8,7 @@ from flwr_datasets import FederatedDataset from flwr.server.strategy import FedXgbBagging -from utils import server_args_parser +from utils import server_args_parser, BST_PARAMS from dataset import resplit, transform_dataset_to_dmatrix @@ -30,16 +30,7 @@ test_dmatrix = transform_dataset_to_dmatrix(test_set) # Hyper-parameters used for initialisation -params = { - "objective": "binary:logistic", - "eta": 0.1, # Learning rate - "max_depth": 8, - "eval_metric": "auc", - "nthread": 16, - "num_parallel_tree": 1, - "subsample": 1, - "tree_method": "hist", -} +params = BST_PARAMS def eval_config(rnd: int) -> Dict[str, str]: diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 51c1a1b9604d..000def370752 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -1,6 +1,18 @@ import argparse +BST_PARAMS = { + "objective": "binary:logistic", + "eta": 0.1, # Learning rate + "max_depth": 8, + "eval_metric": "auc", + "nthread": 16, + "num_parallel_tree": 1, + "subsample": 1, + "tree_method": "hist", +} + + def client_args_parser(): """Parse arguments to define experimental settings on client side.""" parser = argparse.ArgumentParser() diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md new file mode 100644 index 000000000000..5174c236c668 --- /dev/null +++ b/examples/xgboost-quickstart/README.md @@ -0,0 +1,88 @@ +# Flower Example using XGBoost + +This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. +We use [HIGGS](https://archive.ics.uci.edu/dataset/280/higgs) dataset for this example to perform a binary classification task. +Tree-based with bagging method is used for aggregation on the server. + +This project provides a minimal code example to enable you to get stated quickly. For a more comprehensive code example, take a look at [xgboost-comprehensive](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive). + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-quickstart . 
&& rm -rf flower && cd xgboost-quickstart +``` + +This will create a new directory called `xgboost-quickstart` containing the following files: + +``` +-- README.md <- Your're reading this right now +-- server.py <- Defines the server-side logic +-- client.py <- Defines the client-side logic +-- run.sh <- Commands to run experiments +-- pyproject.toml <- Example dependencies (if you use Poetry) +-- requirements.txt <- Example dependencies +``` + +### Installing Dependencies + +Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with XGBoost and Flower + +Afterwards you are ready to start the Flower server as well as the clients. +You can simply start the server in a terminal as follows: + +```shell +python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. +To do so simply open two more terminal windows and run the following commands. + +Start client 1 in the first terminal: + +```shell +python3 client.py --node-id=0 +``` + +Start client 2 in the second terminal: + +```shell +python3 client.py --node-id=1 +``` + +You will see that XGBoost is starting a federated training. + +Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows: + +```shell +poetry run ./run.sh +``` + +Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart) +and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py new file mode 100644 index 000000000000..b5eab59ba14d --- /dev/null +++ b/examples/xgboost-quickstart/client.py @@ -0,0 +1,176 @@ +import argparse +import warnings +from typing import Union +from logging import INFO +from datasets import Dataset, DatasetDict +import xgboost as xgb + +import flwr as fl +from flwr_datasets import FederatedDataset +from flwr.common.logger import log +from flwr.common import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + Parameters, + Status, +) +from flwr_datasets.partitioner import IidPartitioner + + +warnings.filterwarnings("ignore", category=UserWarning) + +# Define arguments parser for the client/node ID. 
+parser = argparse.ArgumentParser() +parser.add_argument( + "--node-id", + default=0, + type=int, + help="Node ID used for the current client.", +) +args = parser.parse_args() + + +# Define data partitioning related functions +def train_test_split(partition: Dataset, test_fraction: float, seed: int): + """Split the data into train and validation set given split rate.""" + train_test = partition.train_test_split(test_size=test_fraction, seed=seed) + partition_train = train_test["train"] + partition_test = train_test["test"] + + num_train = len(partition_train) + num_test = len(partition_test) + + return partition_train, partition_test, num_train, num_test + + +def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix: + """Transform dataset to DMatrix format for xgboost.""" + x = data["inputs"] + y = data["label"] + new_data = xgb.DMatrix(x, label=y) + return new_data + + +# Load (HIGGS) dataset and conduct partitioning +# We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process. +partitioner = IidPartitioner(num_partitions=30) +fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) + +# Load the partition for this `node_id` +log(INFO, "Loading partition...") +partition = fds.load_partition(node_id=args.node_id, split="train") +partition.set_format("numpy") + +# Train/test splitting +train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=0.2, seed=42 +) + +# Reformat data to DMatrix for xgboost +log(INFO, "Reformatting data...") +train_dmatrix = transform_dataset_to_dmatrix(train_data) +valid_dmatrix = transform_dataset_to_dmatrix(valid_data) + +# Hyper-parameters for xgboost training +num_local_round = 1 +params = { + "objective": "binary:logistic", + "eta": 0.1, # Learning rate + "max_depth": 8, + "eval_metric": "auc", + "nthread": 16, + "num_parallel_tree": 1, + "subsample": 1, + "tree_method": "hist", +} + + +# Define Flower client +class XgbClient(fl.client.Client): + def __init__(self): + self.bst = None + self.config = None + + def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: + _ = (self, ins) + return GetParametersRes( + status=Status( + code=Code.OK, + message="OK", + ), + parameters=Parameters(tensor_type="", tensors=[]), + ) + + def _local_boost(self): + # Update trees based on local training data. 
+ for i in range(num_local_round): + self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) + + # Extract the last N=num_local_round trees for sever aggregation + bst = self.bst[ + self.bst.num_boosted_rounds() + - num_local_round : self.bst.num_boosted_rounds() + ] + + return bst + + def fit(self, ins: FitIns) -> FitRes: + if not self.bst: + # First round local training + log(INFO, "Start training at round 1") + bst = xgb.train( + params, + train_dmatrix, + num_boost_round=num_local_round, + evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], + ) + self.config = bst.save_config() + self.bst = bst + else: + for item in ins.parameters.tensors: + global_model = bytearray(item) + + # Load global model into booster + self.bst.load_model(global_model) + self.bst.load_config(self.config) + + bst = self._local_boost() + + local_model = bst.save_raw("json") + local_model_bytes = bytes(local_model) + + return FitRes( + status=Status( + code=Code.OK, + message="OK", + ), + parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), + num_examples=num_train, + metrics={}, + ) + + def evaluate(self, ins: EvaluateIns) -> EvaluateRes: + eval_results = self.bst.eval_set( + evals=[(valid_dmatrix, "valid")], + iteration=self.bst.num_boosted_rounds() - 1, + ) + auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + + return EvaluateRes( + status=Status( + code=Code.OK, + message="OK", + ), + loss=0.0, + num_examples=num_val, + metrics={"AUC": auc}, + ) + + +# Start Flower client +fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml new file mode 100644 index 000000000000..d82535311e58 --- /dev/null +++ b/examples/xgboost-quickstart/pyproject.toml @@ -0,0 +1,15 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "xgboost-quickstart" +version = "0.1.0" +description = "Federated XGBoost with Flower (quickstart)" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr-nightly = ">=1.0,<2.0" +flwr-datasets = ">=0.0.1,<1.0.0" +xgboost = ">=2.0.0,<3.0.0" diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt new file mode 100644 index 000000000000..aefd74097582 --- /dev/null +++ b/examples/xgboost-quickstart/requirements.txt @@ -0,0 +1,3 @@ +flwr-nightly>=1.0, <2.0 +flwr-datasets>=0.0.1, <1.0.0 +xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh new file mode 100755 index 000000000000..6287145bfb5f --- /dev/null +++ b/examples/xgboost-quickstart/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting server" +python server.py & +sleep 5 # Sleep for 5s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python3 client.py --node-id=$i & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/xgboost-quickstart/server.py b/examples/xgboost-quickstart/server.py new file mode 100644 index 000000000000..b45a375ce94f --- /dev/null +++ b/examples/xgboost-quickstart/server.py @@ -0,0 +1,37 @@ +import flwr as fl +from flwr.server.strategy import FedXgbBagging + + +# FL experimental settings +pool_size 
= 2 +num_rounds = 5 +num_clients_per_round = 2 +num_evaluate_clients = 2 + + +def evaluate_metrics_aggregation(eval_metrics): + """Return an aggregated metric (AUC) for evaluation.""" + total_num = sum([num for num, _ in eval_metrics]) + auc_aggregated = ( + sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num + ) + metrics_aggregated = {"AUC": auc_aggregated} + return metrics_aggregated + + +# Define strategy +strategy = FedXgbBagging( + fraction_fit=(float(num_clients_per_round) / pool_size), + min_fit_clients=num_clients_per_round, + min_available_clients=pool_size, + min_evaluate_clients=num_evaluate_clients, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, +) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=num_rounds), + strategy=strategy, +) diff --git a/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt b/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt index a88af0e28974..6db7ecd36987 100644 --- a/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt +++ b/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt @@ -23,11 +23,11 @@ typealias Properties = Map * The `Code` class defines client status codes used in the application. */ enum class Code(val value: Int) { - OK(1), - GET_PROPERTIES_NOT_IMPLEMENTED(2), - GET_PARAMETERS_NOT_IMPLEMENTED(3), - FIT_NOT_IMPLEMENTED(4), - EVALUATE_NOT_IMPLEMENTED(5); + OK(0), + GET_PROPERTIES_NOT_IMPLEMENTED(1), + GET_PARAMETERS_NOT_IMPLEMENTED(2), + FIT_NOT_IMPLEMENTED(3), + EVALUATE_NOT_IMPLEMENTED(4); companion object { fun fromInt(value: Int): Code = values().first { it.value == value } diff --git a/src/py/flwr/__init__.py b/src/py/flwr/__init__.py index d3cbf00747a4..e05799280339 100644 --- a/src/py/flwr/__init__.py +++ b/src/py/flwr/__init__.py @@ -17,12 +17,13 @@ from flwr.common.version import package_version as _package_version -from . import client, common, driver, server, simulation +from . 
import client, common, driver, flower, server, simulation __all__ = [ "client", "common", "driver", + "flower", "server", "simulation", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 0013b74c631c..b39dbbfc33c0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -22,6 +22,7 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client +from flwr.client.flower import Bwd, Flower, Fwd from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address @@ -32,13 +33,15 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.logger import log +from flwr.common.logger import log, warn_experimental_feature from flwr.proto.task_pb2 import TaskIns, TaskRes +from .flower import load_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response -from .message_handler.message_handler import handle, handle_control_message +from .message_handler.message_handler import handle_control_message from .numpy_client import NumPyClient +from .workload_state import WorkloadState def run_client() -> None: @@ -48,6 +51,22 @@ def run_client() -> None: args = _parse_args_client().parse_args() print(args.server) + print(args.callable_dir) + print(args.callable) + + callable_dir = args.callable_dir + if callable_dir is not None: + sys.path.insert(0, callable_dir) + + def _load() -> Flower: + flower: Flower = load_callable(args.callable) + return flower + + return start_client( + server_address=args.server, + load_callable_fn=_load, + transport="grpc-rere", # Only + ) def _parse_args_client() -> argparse.ArgumentParser: @@ -58,8 +77,18 @@ def _parse_args_client() -> argparse.ArgumentParser: parser.add_argument( "--server", - help="Server address", default="0.0.0.0:9092", + help="Server address", + ) + parser.add_argument( + "--callable", + help="For example: `client:flower` or `project.package.module:wrapper.flower`", + ) + parser.add_argument( + "--callable-dir", + default="", + help="Add specified directory to the PYTHONPATH and load callable from there." + " Default: current working directory.", ) return parser @@ -84,6 +113,7 @@ def _check_actionable_client( def start_client( *, server_address: str, + load_callable_fn: Optional[Callable[[], Flower]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -98,6 +128,8 @@ def start_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + load_callable_fn : Optional[Callable[[], Flower]] (default: None) + ... client_fn : Optional[ClientFn] A callable that instantiates a Client. 
(default: None) client : Optional[flwr.client.Client] @@ -146,20 +178,31 @@ class `flwr.client.Client` (default: None) """ event(EventType.START_CLIENT_ENTER) - _check_actionable_client(client, client_fn) + if load_callable_fn is None: + _check_actionable_client(client, client_fn) - if client_fn is None: - # Wrap `Client` instance in `client_fn` - def single_client_factory( - cid: str, # pylint: disable=unused-argument - ) -> Client: - if client is None: # Added this to keep mypy happy - raise Exception( - "Both `client_fn` and `client` are `None`, but one is required" - ) - return client # Always return the same instance + if client_fn is None: + # Wrap `Client` instance in `client_fn` + def single_client_factory( + cid: str, # pylint: disable=unused-argument + ) -> Client: + if client is None: # Added this to keep mypy happy + raise Exception( + "Both `client_fn` and `client` are `None`, but one is required" + ) + return client # Always return the same instance + + client_fn = single_client_factory + + def _load_app() -> Flower: + return Flower(client_fn=client_fn) - client_fn = single_client_factory + load_callable_fn = _load_app + else: + warn_experimental_feature("`load_callable_fn`") + + # At this point, only `load_callable_fn` should be used + # Both `client` and `client_fn` must not be used directly # Initialize connection context manager connection, address = _init_connection(transport, server_address) @@ -190,11 +233,18 @@ def single_client_factory( send(task_res) break + # Load app + app: Flower = load_callable_fn() + # Handle task message - task_res = handle(client_fn, task_ins) + fwd_msg: Fwd = Fwd( + task_ins=task_ins, + state=WorkloadState(state={}), + ) + bwd_msg: Bwd = app(fwd=fwd_msg) # Send - send(task_res) + send(bwd_msg.task_res) # Unregister node if delete_node is not None: diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py new file mode 100644 index 000000000000..9eeb41887e24 --- /dev/null +++ b/src/py/flwr/client/flower.py @@ -0,0 +1,138 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable.""" + + +import importlib +from dataclasses import dataclass +from typing import Callable, cast + +from flwr.client.message_handler.message_handler import handle +from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns, TaskRes + + +@dataclass +class Fwd: + """.""" + + task_ins: TaskIns + state: WorkloadState + + +@dataclass +class Bwd: + """.""" + + task_res: TaskRes + state: WorkloadState + + +FlowerCallable = Callable[[Fwd], Bwd] + + +class Flower: + """Flower callable. + + Examples + -------- + Assuming a typical client implementation in `FlowerClient`, you can wrap it in a + Flower callable as follows: + + >>> class FlowerClient(NumPyClient): + >>> # ... 
+ >>> + >>> def client_fn(cid): + >>> return FlowerClient().to_client() + >>> + >>> flower = Flower(client_fn) + + If the above code is in a Python module called `client`, it can be started as + follows: + + >>> flower-client --callable client:flower + + In this `client:flower` example, `client` refers to the Python module in which the + previous code lives in. `flower` refers to the global attribute `flower` that points + to an object of type `Flower` (a Flower callable). + """ + + def __init__( + self, + client_fn: ClientFn, # Only for backward compatibility + ) -> None: + self.client_fn = client_fn + + def __call__(self, fwd: Fwd) -> Bwd: + """.""" + # Execute the task + task_res = handle( + client_fn=self.client_fn, + task_ins=fwd.task_ins, + ) + return Bwd( + task_res=task_res, + state=WorkloadState(state={}), + ) + + +class LoadCallableError(Exception): + """.""" + + +def load_callable(module_attribute_str: str) -> Flower: + """Load the `Flower` object specified in a module attribute string. + + The module/attribute string should have the form :. Valid + examples include `client:flower` and `project.package.module:wrapper.flower`. It + must refer to a module on the PYTHONPATH, the module needs to have the specified + attribute, and the attribute must be of type `Flower`. + """ + module_str, _, attributes_str = module_attribute_str.partition(":") + if not module_str: + raise LoadCallableError( + f"Missing module in {module_attribute_str}", + ) from None + if not attributes_str: + raise LoadCallableError( + f"Missing attribute in {module_attribute_str}", + ) from None + + # Load module + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError: + raise LoadCallableError( + f"Unable to load module {module_str}", + ) from None + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError: + raise LoadCallableError( + f"Unable to load attribute {attributes_str} from module {module_str}", + ) from None + + # Check type + if not isinstance(attribute, Flower): + raise LoadCallableError( + f"Attribute {attributes_str} is not of type {Flower}", + ) from None + + return cast(Flower, attribute) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index b69228826e13..424e413dc484 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -16,7 +16,7 @@ from contextlib import contextmanager -from logging import DEBUG, ERROR, WARN +from logging import DEBUG, ERROR from pathlib import Path from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast @@ -28,7 +28,7 @@ ) from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.grpc import create_channel -from flwr.common.logger import log +from flwr.common.logger import log, warn_experimental_feature from flwr.proto.fleet_pb2 import ( CreateNodeRequest, DeleteNodeRequest, @@ -88,6 +88,8 @@ def grpc_request_response( create_node : Optional[Callable] delete_node : Optional[Callable] """ + warn_experimental_feature("`grpc-rere`") + if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() @@ -99,14 +101,6 @@ def grpc_request_response( channel.subscribe(on_channel_state_change) stub = FleetStub(channel) - log( - WARN, - """ - EXPERIMENTAL: `grpc-rere` is an experimental transport layer, it might change - considerably in future 
versions of Flower - """, - ) - # Necessary state to link TaskRes to TaskIns state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None} diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index e543d6565878..29d1562a86d3 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -16,7 +16,7 @@ import logging -from logging import LogRecord +from logging import WARN, LogRecord from logging.handlers import HTTPHandler from typing import Any, Dict, Optional, Tuple @@ -97,3 +97,17 @@ def configure( logger = logging.getLogger(LOGGER_NAME) # pylint: disable=invalid-name log = logger.log # pylint: disable=invalid-name + + +def warn_experimental_feature(name: str) -> None: + """Warn the user when they use an experimental feature.""" + log( + WARN, + """ + EXPERIMENTAL FEATURE: %s + + This is an experimental feature. It could change significantly or be removed + entirely in future versions of Flower. + """, + name, + ) diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/flower/__init__.py new file mode 100644 index 000000000000..090c78062d02 --- /dev/null +++ b/src/py/flwr/flower/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Adap GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable package.""" + + +from flwr.client.flower import Bwd as Bwd +from flwr.client.flower import Flower as Flower +from flwr.client.flower import Fwd as Fwd + +__all__ = [ + "Flower", + "Fwd", + "Bwd", +] diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index 020e0ef71267..f300633d0d9f 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -17,7 +17,7 @@ Strategy in the horizontal setting based on building Neural Network and averaging on prediction outcomes. -Paper: Coming +Paper: arxiv.org/abs/2304.07537 """ @@ -35,6 +35,13 @@ class FedXgbNnAvg(FedAvg): """Configurable FedXgbNnAvg strategy implementation.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Federated XGBoost [Ma et al., 2023] strategy. + + Implementation based on https://arxiv.org/abs/2304.07537. + """ + super().__init__(*args, **kwargs) + def __repr__(self) -> str: """Compute a string representation of the strategy.""" rep = f"FedXgbNnAvg(accept_failures={self.accept_failures})"
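
A quick sanity check of the bagging server shown earlier: ``evaluate_metrics_aggregation`` in ``examples/xgboost-quickstart/server.py`` computes an example-count-weighted average of the per-client AUC values before it is handed to ``FedXgbBagging`` via ``evaluate_metrics_aggregation_fn``. The minimal sketch below restates that function and feeds it made-up client results (the example counts and AUC values are purely hypothetical):

.. code-block:: python

    def evaluate_metrics_aggregation(eval_metrics):
        """Return an aggregated metric (AUC) for evaluation."""
        total_num = sum([num for num, _ in eval_metrics])
        auc_aggregated = (
            sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
        )
        return {"AUC": auc_aggregated}

    # Hypothetical results reported by two clients: (num_examples, metrics) pairs
    eval_metrics = [(4000, {"AUC": 0.81}), (6000, {"AUC": 0.77})]

    # Weighted by example counts: (0.81 * 4000 + 0.77 * 6000) / 10000 = 0.786
    print(evaluate_metrics_aggregation(eval_metrics))  # {'AUC': 0.786}

Similarly, the new ``Flower`` callable added in ``src/py/flwr/client/flower.py`` is loaded from a module/attribute string. Following the pattern from its docstring, a minimal sketch of a client module could look as follows (the module name ``client`` and the attribute name ``flower`` are just the docstring's example names, and the client methods are elided):

.. code-block:: python

    # client.py
    from flwr.client import NumPyClient
    from flwr.flower import Flower


    class FlowerClient(NumPyClient):
        ...  # the usual get_parameters / fit / evaluate methods go here


    def client_fn(cid: str):
        return FlowerClient().to_client()


    # Global attribute of type `Flower`, resolvable as `client:flower`
    flower = Flower(client_fn=client_fn)

Such a module would then be started with ``flower-client --callable client:flower``, as described in the ``Flower`` docstring earlier in this diff.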