Skip to content

Commit

Permalink
Merge branch 'main' into inplace_fedavg
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jan 5, 2024
2 parents a1af10e + 2b4297d commit 38e6d55
Show file tree
Hide file tree
Showing 52 changed files with 424 additions and 443 deletions.
2 changes: 1 addition & 1 deletion baselines/hfedxgboost/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dataset: [a9a, cod-rna, ijcnn1, space_ga, cpusmall, YearPredictionMSD]
**Paper:** [arxiv.org/abs/2304.07537](https://arxiv.org/abs/2304.07537)

**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Laneearly_stop_patience_rounds: 100
**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Lane

**Abstract:** The privacy-sensitive nature of decentralized datasets and the robustness of eXtreme Gradient Boosting (XGBoost) on tabular data raise the need to train XGBoost in the context of federated learning (FL). Existing works on federated XGBoost in the horizontal setting rely on the sharing of gradients, which induce per-node level communication frequency and serious privacy concerns. To alleviate these problems, we develop an innovative framework for horizontal federated XGBoost which does not depend on the sharing of gradients and simultaneously boosts privacy and communication efficiency by making the learning rates of the aggregated tree ensembles learnable. We conduct extensive evaluations on various classification and regression datasets, showing our approach achieves performance comparable to the state-of-the-art method and effectively improves communication efficiency by lowering both communication rounds and communication overhead by factors ranging from 25x to 700x.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@
" min_available_clients=10, # Wait until all 10 clients are available\n",
")\n",
"\n",
"# Specify the resources each of your clients need. By default, each \n",
"# Specify the resources each of your clients need. By default, each\n",
"# client will be allocated 1x CPU and 0x GPUs\n",
"client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
"if DEVICE.type == \"cuda\":\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main(cfg: DictConfig):
save_path = HydraConfig.get().runtime.output_dir

## 2. Prepare your dataset
# When simulating FL workloads we have a lot of freedom on how the FL clients behave,
# When simulating FL runs we have a lot of freedom on how the FL clients behave,
# what data they have, how much data, etc. This is not possible in real FL settings.
# In simulation you'd often encounter two types of dataset:
# * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST,
Expand Down Expand Up @@ -91,7 +91,7 @@ def main(cfg: DictConfig):
"num_gpus": 0.0,
}, # (optional) controls the degree of parallelism of your simulation.
# Lower resources per client allow for more clients to run concurrently
# (but need to be set taking into account the compute/memory footprint of your workload)
# (but need to be set taking into account the compute/memory footprint of your run)
# `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated
# `num_gpus` is a ratio indicating the portion of gpu memory that a client needs.
)
Expand Down
12 changes: 6 additions & 6 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

workload_id = create_workload_res.workload_id
print(f"Created workload id {workload_id}")
run_id = create_run_res.run_id
print(f"Created run id {run_id}")

history = History()
for server_round in range(num_rounds):
Expand Down Expand Up @@ -93,7 +93,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node IDs from the server
get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)
get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
Expand Down Expand Up @@ -125,7 +125,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
new_task_ins = task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
workload_id=workload_id,
run_id=run_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand Down
14 changes: 7 additions & 7 deletions examples/pytorch-from-centralized-to-federated/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def apply_transforms(batch):


def train(
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device, # pylint: disable=no-member
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device, # pylint: disable=no-member
) -> None:
"""Train the network."""
# Define loss and optimizer
Expand Down Expand Up @@ -110,9 +110,9 @@ def train(


def test(
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device, # pylint: disable=no-member
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device, # pylint: disable=no-member
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
# Define loss and metrics
Expand Down
12 changes: 6 additions & 6 deletions examples/pytorch-from-centralized-to-federated/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class CifarClient(fl.client.NumPyClient):
"""Flower client implementing CIFAR-10 image classification using PyTorch."""

def __init__(
self,
model: cifar.Net,
trainloader: DataLoader,
testloader: DataLoader,
self,
model: cifar.Net,
trainloader: DataLoader,
testloader: DataLoader,
) -> None:
self.model = model
self.trainloader = trainloader
Expand Down Expand Up @@ -61,15 +61,15 @@ def set_parameters(self, parameters: List[np.ndarray]) -> None:
self.model.load_state_dict(state_dict, strict=True)

def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
return self.get_parameters(config={}), len(self.trainloader.dataset), {}

def evaluate(
self, parameters: List[np.ndarray], config: Dict[str, str]
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

disable_progress_bar()


class FlowerClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
Expand Down Expand Up @@ -55,7 +56,6 @@ def _set_parameters(model, parameters):


def main() -> None:

parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
Expand Down
16 changes: 10 additions & 6 deletions examples/quickstart-pytorch-lightning/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,20 @@ def load_data(partition):
# 60 % for the federated train and 20 % for the federated validation (both in fit)
partition_train_valid = partition_full["train"].train_test_split(train_size=0.75)
trainloader = DataLoader(
partition_train_valid["train"], batch_size=32,
shuffle=True, collate_fn=collate_fn, num_workers=1
partition_train_valid["train"],
batch_size=32,
shuffle=True,
collate_fn=collate_fn,
num_workers=1,
)
valloader = DataLoader(
partition_train_valid["test"], batch_size=32,
collate_fn=collate_fn, num_workers=1
partition_train_valid["test"],
batch_size=32,
collate_fn=collate_fn,
num_workers=1,
)
testloader = DataLoader(
partition_full["test"], batch_size=32,
collate_fn=collate_fn, num_workers=1
partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1
)
return trainloader, valloader, testloader

Expand Down
4 changes: 3 additions & 1 deletion examples/quickstart-sklearn-tabular/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"test_accuracy": accuracy}

# Start Flower client
fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client())
fl.client.start_client(
server_address="0.0.0.0:8080", client=IrisClient().to_client()
)
13 changes: 7 additions & 6 deletions examples/secaggplus-mt/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
workload_id=workload_id,
run_id=run_id,
run_id=run_id,
task=merge(
task,
task_pb2.Task(
Expand Down Expand Up @@ -84,13 +85,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

workload_id = create_workload_res.workload_id
print(f"Created workload id {workload_id}")
run_id = create_run_res.run_id
print(f"Created run id {run_id}")

history = History()
for server_round in range(num_rounds):
Expand Down Expand Up @@ -119,7 +120,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node IDs from the server
get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)
get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
Expand Down
12 changes: 6 additions & 6 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import "flwr/proto/node.proto";
import "flwr/proto/task.proto";

service Driver {
// Request workload_id
rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {}
// Request run_id
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Return a set of nodes
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}
Expand All @@ -34,12 +34,12 @@ service Driver {
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

// CreateWorkload
message CreateWorkloadRequest {}
message CreateWorkloadResponse { sint64 workload_id = 1; }
// CreateRun
message CreateRunRequest {}
message CreateRunResponse { sint64 run_id = 1; }

// GetNodes messages
message GetNodesRequest { sint64 workload_id = 1; }
message GetNodesRequest { sint64 run_id = 1; }
message GetNodesResponse { repeated Node nodes = 1; }

// PushTaskIns messages
Expand Down
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ message Task {
message TaskIns {
string task_id = 1;
string group_id = 2;
sint64 workload_id = 3;
sint64 run_id = 3;
Task task = 4;
}

message TaskRes {
string task_id = 1;
string group_id = 2;
sint64 workload_id = 3;
sint64 run_id = 3;
Task task = 4;
}

Expand Down
12 changes: 5 additions & 7 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,24 +349,22 @@ def _load_app() -> Flower:
break

# Register state
node_state.register_workloadstate(workload_id=task_ins.workload_id)
node_state.register_runstate(run_id=task_ins.run_id)

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
state=node_state.retrieve_workloadstate(
workload_id=task_ins.workload_id
),
state=node_state.retrieve_runstate(run_id=task_ins.run_id),
)
bwd_msg: Bwd = app(fwd=fwd_msg)

# Update node state
node_state.update_workloadstate(
workload_id=bwd_msg.task_res.workload_id,
workload_state=bwd_msg.state,
node_state.update_runstate(
run_id=bwd_msg.task_res.run_id,
run_state=bwd_msg.state,
)

# Send
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC

from flwr.client.workload_state import WorkloadState
from flwr.client.run_state import RunState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -38,7 +38,7 @@
class Client(ABC):
"""Abstract base class for Flower clients."""

state: WorkloadState
state: RunState

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,12 +141,12 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
def get_state(self) -> RunState:
"""Get the run state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
def set_state(self, state: RunState) -> None:
"""Apply a run state to this client."""
self.state = state

def to_client(self) -> Client:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def receive() -> TaskIns:
return TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
workload_id=0,
run_id=0,
task=Task(
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
Expand Down
Loading

0 comments on commit 38e6d55

Please sign in to comment.