diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/README.md b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md
new file mode 100644
index 0000000000..6138709194
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/README.md
@@ -0,0 +1,18 @@
+# Install FedML and Prepare the Distributed Environment
+```
+pip install fedml
+```
+
+
+# Run the example
+
+## MPI hierarchical FL
+```
+sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config.yaml
+```
+
+## MPI hierarchical FL on a custom topology (e.g., 2d_torus, star, complete, isolated, balanced_tree, or random)
+```
+sh run_step_by_step_example.sh 5 config/mnist_lr/fedml_config_topo.yaml
+```
+
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/__init__.py b/python/examples/simulation/mpi_torch_hierarchical_fl/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh
new file mode 100755
index 0000000000..fcbe10bde3
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/batch_run.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+
+GROUP_NUM=5
+GROUP_METHOD="hetero"
+COMM_ROUND=62 #250
+GROUP_COMM_ROUND=4 # 1
+TOPO_NAME="star"
+CONFIG_PATH=config/mnist_lr/fedml_config_topo.yaml
+
+group_alpha_list=(0.01 0.1 1.0)
+
+WORKER_NUM=$(($GROUP_NUM+1))
+hostname > mpi_host_file
+mkdir -p batch_log
+# we need to install yq (https://github.com/mikefarah/yq)
+# wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq && chmod +x /usr/bin/yq
+
+yq -i ".device_args.worker_num = ${WORKER_NUM}" $CONFIG_PATH
+yq -i ".device_args.gpu_mapping_key = \"mapping_config1_${WORKER_NUM}\"" $CONFIG_PATH
+yq -i ".train_args.group_num = ${GROUP_NUM}" $CONFIG_PATH
+yq -i ".train_args.comm_round = ${COMM_ROUND}" $CONFIG_PATH
+yq -i ".train_args.group_comm_round = ${GROUP_COMM_ROUND}" $CONFIG_PATH
+yq -i ".train_args.group_method = \"${GROUP_METHOD}\"" $CONFIG_PATH
+yq -i ".train_args.topo_name = \"${TOPO_NAME}\"" $CONFIG_PATH
+
+if [ "${GROUP_METHOD}" = "random" ]; then
+  yq -i ".train_args.group_alpha = 0" $CONFIG_PATH
+fi
+
+if [ "${TOPO_NAME}" != "random" ]; then
+  yq -i ".train_args.topo_edge_probability = 1.0" $CONFIG_PATH
+fi
+
+
+for group_alpha in "${group_alpha_list[@]}";
+do
+  echo "group_alpha=$group_alpha"
+  yq -i ".train_args.group_alpha = ${group_alpha}" $CONFIG_PATH
+
+  nohup mpirun -np $WORKER_NUM \
+  -hostfile mpi_host_file \
+  python torch_step_by_step_example.py --cf $CONFIG_PATH \
+  > batch_log/"group_alpha=$group_alpha.log" 2>&1 & echo $! >> batch_log/group_alpha.pid
+  sleep 30
+done
+
+echo "Finished!"
\ No newline at end of file
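For environments where installing `yq` is not an option, the same in-place config edits can be scripted with PyYAML. This is a minimal sketch, assuming `pip install pyyaml` and the default values from `batch_run.sh` above; note that, unlike `yq -i`, a PyYAML round trip drops YAML comments.

```python
# Hedged stand-in for the yq calls in batch_run.sh, assuming PyYAML is
# installed. Values mirror the script's defaults (GROUP_NUM=5 -> WORKER_NUM=6).
import yaml

config_path = "config/mnist_lr/fedml_config_topo.yaml"

with open(config_path) as f:
    cfg = yaml.safe_load(f)

cfg["device_args"]["worker_num"] = 6
cfg["device_args"]["gpu_mapping_key"] = "mapping_config1_6"
cfg["train_args"]["group_num"] = 5
cfg["train_args"]["comm_round"] = 62
cfg["train_args"]["group_comm_round"] = 4
cfg["train_args"]["group_method"] = "hetero"
cfg["train_args"]["topo_name"] = "star"

with open(config_path, "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```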
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml
new file mode 100644
index 0000000000..6543f7e033
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config.yaml
@@ -0,0 +1,49 @@
+common_args:
+  training_type: "simulation"
+  random_seed: 0
+
+data_args:
+  dataset: "mnist"
+  data_cache_dir: ~/fedml_data
+  partition_method: "hetero"
+  partition_alpha: 0.5
+
+model_args:
+  model: "lr"
+
+train_args:
+  federated_optimizer: "HierarchicalFL"
+  client_id_list: "[]"
+  client_num_in_total: 1000
+  client_num_per_round: 20
+  comm_round: 20
+  epochs: 1
+  batch_size: 10
+  client_optimizer: sgd
+  learning_rate: 0.03
+  weight_decay: 0.001
+  group_method: "random"
+  group_num: 4
+  group_comm_round: 5
+
+validation_args:
+  frequency_of_the_test: 5
+
+device_args:
+  worker_num: 5
+  using_gpu: true
+  gpu_mapping_file: config/mnist_lr/gpu_mapping.yaml
+  gpu_mapping_key: mapping_config1_5
+
+comm_args:
+  backend: "MPI"
+  is_mobile: 0
+
+
+tracking_args:
+  # When running on the MLOps platform (open.fedml.ai), the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/
+  enable_wandb: true
+  wandb_key: <your_wandb_api_key>  # never commit a real API key
+  wandb_project: fedml
+  run_name: mpi_hierarchical_fl_mnist_lr
+  wandb_only_server: true
\ No newline at end of file
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml
new file mode 100644
index 0000000000..43b680d37d
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/fedml_config_topo.yaml
@@ -0,0 +1,52 @@
+common_args:
+  training_type: "simulation"
+  random_seed: 0
+
+data_args:
+  dataset: "mnist"
+  data_cache_dir: ~/fedml_data
+  partition_method: "hetero"
+  partition_alpha: 0.5
+
+model_args:
+  model: "lr"
+
+train_args:
+  federated_optimizer: "HierarchicalFL"
+  client_id_list: "[]"
+  client_num_in_total: 1000
+  client_num_per_round: 20
+  comm_round: 20
+  epochs: 1
+  batch_size: 10
+  client_optimizer: sgd
+  learning_rate: 0.03
+  weight_decay: 0.001
+  group_method: "hetero"
+  group_alpha: 0.5
+  group_num: 4
+  group_comm_round: 5
+  topo_name: "ring"
+  topo_edge_probability: 0.5
+
+validation_args:
+  frequency_of_the_test: 5
+
+device_args:
+  worker_num: 5
+  using_gpu: true
+  gpu_mapping_file: config/mnist_lr/gpu_mapping.yaml
+  gpu_mapping_key: mapping_config1_5
+
+comm_args:
+  backend: "MPI"
+  is_mobile: 0
+
+
+tracking_args:
+  # When running on the MLOps platform (open.fedml.ai), the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/
+  enable_wandb: true
+  wandb_key: <your_wandb_api_key>  # never commit a real API key
+  wandb_project: fedml
+  run_name: mpi_hierarchical_fl_mnist_lr
+  wandb_only_server: true
\ No newline at end of file
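The two round counters in these configs multiply rather than add: every cloud round (`comm_round`) triggers `group_comm_round` partial aggregations on the edge servers, and the cloud logs and tests against the resulting global index (see `HierGroup.train` and the cloud aggregator later in this diff). A small sketch of the bookkeeping:

```python
# Round bookkeeping as used by HierGroup.train and the cloud aggregator below.
comm_round = 20        # train_args.comm_round (cloud rounds)
group_comm_round = 5   # train_args.group_comm_round (edge rounds per cloud round)

for round_idx in range(comm_round):
    for group_round_idx in range(group_comm_round):
        # the global index attached to each partially aggregated model
        global_round_idx = round_idx * group_comm_round + group_round_idx

assert global_round_idx == comm_round * group_comm_round - 1  # 99 here
```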
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml
new file mode 100644
index 0000000000..8c4961681f
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/config/mnist_lr/gpu_mapping.yaml
@@ -0,0 +1,70 @@
+# You can define a cluster containing multiple GPUs within multiple machines by defining `gpu_mapping.yaml` as follows:
+
+# config_cluster0:
+#     host_name_node0: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n]
+#     host_name_node1: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n]
+#     host_name_node_m: [num_of_processes_on_GPU0, num_of_processes_on_GPU1, num_of_processes_on_GPU2, num_of_processes_on_GPU3, ..., num_of_processes_on_GPU_n]
+
+
+# this is used for 10 clients and 1 server training within a single machine which has 4 GPUs
+mapping_default:
+    ChaoyangHe-GPU-RTX2080Tix4: [3, 3, 3, 2]
+
+mapping_config1_2:
+    host1: [1, 1]
+
+mapping_config1_3:
+    host1: [1, 1, 1]
+
+# this is used for 4 clients and 1 server training within a single machine which has 4 GPUs
+mapping_config1_5:
+    host1: [2, 1, 1, 1]
+
+# this is used for 5 clients and 1 server training within a single machine which has 4 GPUs
+mapping_config1_6:
+    host1: [2, 2, 1, 1]
+
+# this is used for 10 clients and 1 server training within a single machine which has 4 GPUs
+mapping_config2_11:
+    host1: [3, 3, 3, 2]
+
+# this is used for 10 clients and 1 server training within a single machine which has 8 GPUs
+mapping_config3_11:
+    host1: [2, 2, 2, 1, 1, 1, 1, 1]
+
+# this is used for 4 clients and 1 server training within a single machine which has 8 GPUs, but you hope to skip some GPU device IDs.
+mapping_config4_5:
+    host1: [1, 0, 0, 1, 1, 0, 1, 1]
+
+# this is used for 4 clients and 1 server training using 5 machines, each machine has 2 GPUs inside, but you hope to use only the second GPU.
+mapping_config5_6:
+    host1: [0, 1]
+    host2: [0, 1]
+    host3: [0, 1]
+    host4: [0, 1]
+    host5: [0, 1]
+# this is used for 4 clients and 1 server training using 2 machines, each machine has 2 GPUs inside.
+mapping_config5_2:
+    gpu-worker2: [1,1]
+    gpu-worker1: [2,1]
+
+# this is used for 10 clients and 1 server training using 4 machines, each machine has 2 GPUs inside.
+mapping_config5_4:
+    gpu-worker2: [1,1]
+    gpu-worker1: [2,1]
+    gpu-worker3: [3,1]
+    gpu-worker4: [1,1]
+
+# for grpc GPU mapping
+mapping_FedML_gRPC:
+    hostname_node_server: [1]
+    hostname_node_1: [1, 0, 0, 0]
+    hostname_node_2: [1, 0, 0, 0]
+
+# for torch RPC GPU mapping
+mapping_FedML_tRPC:
+    lambda-server1: [0, 0, 0, 0, 2, 2, 1, 1]
+    lambda-server2: [2, 1, 1, 1, 0, 0, 0, 0]
+
+#mapping_FedML_tRPC:
+#    lambda-server1: [0, 0, 0, 0, 3, 3, 3, 2]
\ No newline at end of file
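As the comments in `gpu_mapping.yaml` state, entry *g* of each host's list is the number of MPI processes placed on GPU *g* of that host. The expansion below is illustrative only (not FedML's actual loader) and shows how `mapping_config1_5` yields 5 workers on 4 GPUs:

```python
# Illustrative expansion of a gpu_mapping row: entry g of the list is the
# number of MPI ranks pinned to GPU g on that host.
mapping = {"host1": [2, 1, 1, 1]}  # mapping_config1_5

process_to_gpu = []
for host, per_gpu in mapping.items():
    for gpu_id, n_procs in enumerate(per_gpu):
        process_to_gpu += [(host, gpu_id)] * n_procs

# process_to_gpu == [('host1', 0), ('host1', 0), ('host1', 1),
#                    ('host1', 2), ('host1', 3)]  -> 5 workers on 4 GPUs
```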
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file b/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file
new file mode 100644
index 0000000000..ebed096720
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/mpi_host_file
@@ -0,0 +1 @@
+liuxuezheng3
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh b/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh
new file mode 100755
index 0000000000..8dd565e6e8
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/run_step_by_step_example.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+WORKER_NUM=$1
+CONFIG_PATH=$2
+
+hostname > mpi_host_file
+
+mpirun -np $WORKER_NUM \
+-hostfile mpi_host_file \
+python torch_step_by_step_example.py --cf $CONFIG_PATH
\ No newline at end of file
diff --git a/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py b/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py
new file mode 100644
index 0000000000..aa52f56397
--- /dev/null
+++ b/python/examples/simulation/mpi_torch_hierarchical_fl/torch_step_by_step_example.py
@@ -0,0 +1,19 @@
+import fedml
+from fedml import FedMLRunner
+
+if __name__ == "__main__":
+    # init FedML framework
+    args = fedml.init()
+
+    # init device
+    device = fedml.device.get_device(args)
+
+    # load data
+    dataset, output_dim = fedml.data.load(args)
+
+    # load model
+    model = fedml.model.create(args, output_dim)
+
+    # start training
+    fedml_runner = FedMLRunner(args, device, dataset, model)
+    fedml_runner.run()
\ No newline at end of file
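Every MPI rank executes this same script; `fedml.init()` picks up the rank from the MPI launcher, and `FedML_HierFedAvg_distributed` (added later in this diff) turns rank 0 into the cloud server and ranks 1..N into edge servers. A sketch of the role split, using mpi4py only to make the rank logic explicit (FedML's MPI backend resolves ranks internally):

```python
# Illustration only: how the hierarchical example maps MPI ranks to roles
# under `mpirun -np 6`. Assumes mpi4py-style rank discovery.
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()   # worker_num = group_num + 1

if rank == 0:
    role = "cloud server (HierFedAVGCloudManager)"
else:
    role = "edge server #%d (HierFedAVGEdgeManager)" % (rank - 1)
print("rank %d of %d -> %s" % (rank, size, role))
```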
diff --git a/python/fedml/core/distributed/topology/symmetric_topology_manager.py b/python/fedml/core/distributed/topology/symmetric_topology_manager.py
index 07d90525e4..4ce736f2f7 100644
--- a/python/fedml/core/distributed/topology/symmetric_topology_manager.py
+++ b/python/fedml/core/distributed/topology/symmetric_topology_manager.py
@@ -2,6 +2,7 @@
 import numpy as np
 
 from .base_topology_manager import BaseTopologyManager
+from .topo_utils import *
 
 
 class SymmetricTopologyManager(BaseTopologyManager):
@@ -18,6 +19,27 @@ def __init__(self, n, neighbor_num=2):
         self.neighbor_num = neighbor_num
         self.topology = []
 
+    def generate_custom_topology(self, args):
+        topo_name = args.topo_name
+        if topo_name == 'ring':
+            self.neighbor_num = 2
+            self.generate_topology()
+        elif topo_name == '2d_torus':
+            self.topology = get_2d_torus_overlay(self.n)
+        elif topo_name == 'star':
+            self.topology = get_star_overlay(self.n)
+        elif topo_name == 'complete':
+            self.topology = get_complete_overlay(self.n)
+        elif topo_name == 'isolated':
+            self.topology = get_isolated_overlay(self.n)
+        elif topo_name == 'balanced_tree':
+            self.topology = get_balanced_tree_overlay(self.n, self.neighbor_num)
+        elif topo_name == 'random':
+            probability = args.topo_edge_probability  # Probability for edge creation
+            self.topology = get_random_overlay(self.n, probability)
+        else:
+            raise ValueError("unknown topo_name: %s" % topo_name)
+
     def generate_topology(self):
         # first generate a ring topology
         topology_ring = np.array(
@@ -84,8 +106,12 @@ def get_out_neighbor_idx_list(self, node_index):
 
 if __name__ == "__main__":
+    from types import SimpleNamespace
+
     # generate a ring topology
-    tpmgr = SymmetricTopologyManager(6, 2)
-    tpmgr.generate_topology()
+    tpmgr = SymmetricTopologyManager(9, 2)
+    # generate_custom_topology expects an args-style namespace carrying topo_name
+    # (plus topo_edge_probability when topo_name == 'random')
+    tpmgr.generate_custom_topology(SimpleNamespace(topo_name='random', topo_edge_probability=0.3))
     print("tpmgr.topology = " + str(tpmgr.topology))
 
     # get the OUT neighbor weights for node 1
diff --git a/python/fedml/core/distributed/topology/topo_utils.py b/python/fedml/core/distributed/topology/topo_utils.py
new file mode 100644
index 0000000000..bba541dbe2
--- /dev/null
+++ b/python/fedml/core/distributed/topology/topo_utils.py
@@ -0,0 +1,94 @@
+import math
+import numpy as np
+import networkx as nx
+
+
+def get_2d_torus_overlay(node_num):
+    side_len = node_num ** 0.5
+    assert math.ceil(side_len) == math.floor(side_len)
+    side_len = int(side_len)
+
+    torus = np.zeros((node_num, node_num), dtype=np.float32)
+
+    for i in range(side_len):
+        for j in range(side_len):
+            idx = i * side_len + j
+            torus[idx, idx] = 1 / 5
+            torus[idx, (((i + 1) % side_len) * side_len + j)] = 1 / 5
+            torus[idx, (((i - 1) % side_len) * side_len + j)] = 1 / 5
+            torus[idx, (i * side_len + (j + 1) % side_len)] = 1 / 5
+            torus[idx, (i * side_len + (j - 1) % side_len)] = 1 / 5
+
+    return torus
+
+
+def get_star_overlay(node_num):
+
+    star = np.zeros((node_num, node_num), dtype=np.float32)
+    for i in range(node_num):
+        if i == 0:
+            star[i, i] = 1 / node_num
+        else:
+            star[0, i] = star[i, 0] = 1 / node_num
+            star[i, i] = 1 - 1 / node_num
+
+    return star
+
+
+def get_complete_overlay(node_num):
+
+    complete = np.ones((node_num, node_num), dtype=np.float32)
+    complete /= node_num
+
+    return complete
+
+
+def get_isolated_overlay(node_num):
+
+    isolated = np.zeros((node_num, node_num), dtype=np.float32)
+
+    for i in range(node_num):
+        isolated[i, i] = 1
+
+    return isolated
+
+
+def get_balanced_tree_overlay(node_num, degree=2):
+
+    tree = np.zeros((node_num, node_num), dtype=np.float32)
+
+    for i in range(node_num):
+        for j in range(1, degree + 1):
+            k = i * degree + j
+            if k >= node_num:
+                break
+            tree[i, k] = tree[k, i] = 1 / (degree + 1)
+
+    for i in range(node_num):
+        tree[i, i] = 1 - tree[i, :].sum()
+
+    return tree
+
+
+def get_barbell_overlay(node_num, m1=1, m2=0):
+
+    barbell = None  # TODO: not implemented yet
+
+    return barbell
+
+
+def get_random_overlay(node_num, probability=0.5):
+
+    random = np.array(
+        nx.to_numpy_array(nx.fast_gnp_random_graph(node_num, probability)), dtype=np.float32
+    )
+
+    matrix_sum = random.sum(1)
+
+    for i in range(node_num):
+        for j in range(node_num):
+            if i != j and random[i, j] > 0:
+                random[i, j] = 1 / (1 + max(matrix_sum[i], matrix_sum[j]))
+        random[i, i] = 1 - random[i].sum()
+
+    return random
\ No newline at end of file
diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py
new file mode 100644
index 0000000000..62c25ba683
--- /dev/null
+++ b/python/fedml/simulation/mpi/hierarchical_fl/HierClient.py
@@ -0,0 +1,50 @@
+import copy
+
+import torch
+import torch.nn as nn
+
+from ...sp.fedavg.client import Client
+
+
+class HFLClient(Client):
+    def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model,
+                 model_trainer):
+
+        super().__init__(client_idx, local_training_data, local_test_data, local_sample_number, args, device,
+                         model_trainer)
+        self.client_idx = client_idx
+        self.local_training_data = local_training_data
+        self.local_test_data = local_test_data
+
self.local_sample_number = local_sample_number + + self.args = args + self.device = device + self.model = model + self.model_trainer = model_trainer + self.criterion = nn.CrossEntropyLoss().to(device) + + def train(self, w, scaled_loss_factor=1.0): + self.model.load_state_dict(w) + self.model.to(self.device) + + scaled_loss_factor = min(scaled_loss_factor, 1.0) + if self.args.client_optimizer == "sgd": + optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate * scaled_loss_factor) + else: + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.args.learning_rate * scaled_loss_factor, + weight_decay=self.args.weight_decay, + amsgrad=True, + ) + + for epoch in range(self.args.epochs): + for x, labels in self.local_training_data: + x, labels = x.to(self.device), labels.to(self.device) + self.model.zero_grad() + log_probs = self.model(x) + loss = self.criterion(log_probs, labels) # pylint: disable=E1102 + loss.backward() + optimizer.step() + + return copy.deepcopy(self.model.cpu().state_dict()) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py new file mode 100644 index 0000000000..d3095d4607 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgAPI.py @@ -0,0 +1,191 @@ +from .HierFedAvgCloudAggregator import HierFedAVGCloudAggregator +from .HierFedAvgCloudManager import HierFedAVGCloudManager +from .HierFedAvgEdgeManager import HierFedAVGEdgeManager +from .HierGroup import HierGroup +from .utils import analyze_clients_type, hetero_partition_groups, stats_group +from ....core import ClientTrainer, ServerAggregator +from ....core.dp.fedml_differential_privacy import FedMLDifferentialPrivacy +from ....core.security.fedml_attacker import FedMLAttacker +from ....core.security.fedml_defender import FedMLDefender +from ....ml.aggregator.aggregator_creator import create_server_aggregator +from ....ml.trainer.trainer_creator import create_model_trainer +from ....core.distributed.topology.symmetric_topology_manager import SymmetricTopologyManager + + +import numpy as np +import wandb + +def FedML_HierFedAvg_distributed( + args, + process_id, + worker_number, + comm, + device, + dataset, + model, + client_trainer: ClientTrainer = None, + server_aggregator: ServerAggregator = None, +): + [ + train_data_num, + test_data_num, + train_data_global, + test_data_global, + train_data_local_num_dict, + train_data_local_dict, + test_data_local_dict, + class_num, + ] = dataset + + FedMLAttacker.get_instance().init(args) + FedMLDefender.get_instance().init(args) + FedMLDifferentialPrivacy.get_instance().init(args) + + if process_id == 0: + init_cloud_server( + args, + device, + comm, + process_id, + worker_number, + model, + train_data_num, + train_data_global, + test_data_global, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + server_aggregator, + class_num + ) + else: + init_edge_server_clients( + args, + device, + comm, + process_id, + worker_number, + model, + train_data_num, + train_data_local_num_dict, + train_data_local_dict, + test_data_local_dict, + client_trainer, + server_aggregator + ) + + +def init_cloud_server( + args, + device, + comm, + rank, + size, + model, + train_data_num, + train_data_global, + test_data_global, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + server_aggregator, + class_num +): + if server_aggregator is None: + server_aggregator = 
create_server_aggregator(model, args)
+    server_aggregator.set_id(-1)
+
+    worker_num = size - 1
+
+    # set up topology
+    topology_manager = None
+    if hasattr(args, "topo_name"):
+        topology_manager = SymmetricTopologyManager(worker_num)  # neighbor_num defaults to 2
+        topology_manager.generate_custom_topology(args)
+
+    # aggregator
+    aggregator = HierFedAVGCloudAggregator(
+        train_data_global,
+        test_data_global,
+        train_data_num,
+        train_data_local_dict,
+        test_data_local_dict,
+        train_data_local_num_dict,
+        worker_num,
+        device,
+        args,
+        server_aggregator
+    )
+
+    # start the distributed training
+    backend = args.backend
+    group_indexes, group_to_client_indexes = setup_clients(args, train_data_local_dict, class_num)
+
+    # print group detail
+    stats_group(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num, args)
+
+    server_manager = HierFedAVGCloudManager(args, aggregator, group_indexes, group_to_client_indexes,
+                                            comm, rank, size, backend, topology_manager)
+    server_manager.send_init_msg()
+    server_manager.run()
+
+
+def init_edge_server_clients(
+    args,
+    device,
+    comm,
+    process_id,
+    size,
+    model,
+    train_data_num,
+    train_data_local_num_dict,
+    train_data_local_dict,
+    test_data_local_dict,
+    client_trainer=None,
+    server_aggregator=None,
+):
+
+    # the call site passes client_trainer positionally for the edge-side trainer
+    model_trainer = client_trainer or create_model_trainer(model, args)
+
+    edge_index = process_id - 1
+    backend = args.backend
+
+    # Client assignment is decided on the cloud server; the information will be communicated later
+    group = HierGroup(
+        edge_index,
+        train_data_local_dict,
+        test_data_local_dict,
+        train_data_local_num_dict,
+        args,
+        device,
+        model,
+        model_trainer
+    )
+
+    edge_manager = HierFedAVGEdgeManager(group, args, comm, process_id, size, backend)
+    edge_manager.run()
+
+
+def setup_clients(
+    args,
+    train_data_local_dict,
+    class_num
+    ):
+
+    if args.group_method == "random":
+        group_indexes = np.random.randint(
+            0, args.group_num, args.client_num_in_total
+        )
+        group_to_client_indexes = {}
+        for client_idx, group_idx in enumerate(group_indexes):
+            if group_idx not in group_to_client_indexes:
+                group_to_client_indexes[group_idx] = []
+            group_to_client_indexes[group_idx].append(client_idx)
+    elif args.group_method == "hetero":
+        clients_type_list = analyze_clients_type(train_data_local_dict, class_num, num_type=args.group_num)
+        group_indexes, group_to_client_indexes = hetero_partition_groups(clients_type_list,
+                                                                         args.group_num,
+                                                                         alpha=args.group_alpha)
+
+    return group_indexes, group_to_client_indexes
\ No newline at end of file
diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py
new file mode 100644
index 0000000000..ee23e14bff
--- /dev/null
+++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudAggregator.py
@@ -0,0 +1,257 @@
+import copy
+import logging
+import random
+import time
+import numpy as np
+import torch
+
+from ....core.security.fedml_attacker import FedMLAttacker
+from ....core.security.fedml_defender import FedMLDefender
+from .utils import cal_mixing_consensus_speed
+
+
+class HierFedAVGCloudAggregator(object):
+    def __init__(
+        self,
+        train_global,
+        test_global,
+        all_train_data_num,
+        train_data_local_dict,
+        test_data_local_dict,
+        train_data_local_num_dict,
+        worker_num,
+        device,
+        args,
+        server_aggregator,
+    ):
+        self.aggregator = server_aggregator
+        self.args = args
+        self.train_global = train_global
+        self.test_global = test_global
+        self.val_global =
self._generate_validation_set() + self.all_train_data_num = all_train_data_num + + self.train_data_local_dict = train_data_local_dict + self.test_data_local_dict = test_data_local_dict + self.train_data_local_num_dict = train_data_local_num_dict + + self.worker_num = worker_num + self.device = device + self.model_dict = dict() + self.sample_num_dict = dict() + self.flag_client_model_uploaded_dict = dict() + for idx in range(self.worker_num): + self.flag_client_model_uploaded_dict[idx] = False + + def get_global_model_params(self): + return self.aggregator.get_model_params() + + def set_global_model_params(self, model_parameters): + self.aggregator.set_model_params(model_parameters) + + def add_local_trained_result(self, index, model_params_list, sample_num): + logging.info("add_model. index = %d" % index) + self.model_dict[index] = model_params_list + self.sample_num_dict[index] = sample_num + self.flag_client_model_uploaded_dict[index] = True + + def check_whether_all_receive(self): + logging.debug("worker_num = {}".format(self.worker_num)) + for idx in range(self.worker_num): + if not self.flag_client_model_uploaded_dict[idx]: + return False + for idx in range(self.worker_num): + self.flag_client_model_uploaded_dict[idx] = False + return True + + def aggregate(self): + start_time = time.time() + + # Edge server may conduct partial aggregation multiple times, so cloud server will receive a model list + group_comm_round = len(self.sample_num_dict[0]) + + for group_round_idx in range(group_comm_round): + model_list = [] + global_round_idx = self.model_dict[0][group_round_idx][0] + + for idx in range(0, self.worker_num): + model_list.append((self.sample_num_dict[idx][group_round_idx], + self.model_dict[idx][group_round_idx][1])) + + averaged_params = self._fedavg_aggregation_(model_list) + self.set_global_model_params(averaged_params) + self.test_on_cloud_for_all_clients(global_round_idx) + + if FedMLAttacker.get_instance().is_model_attack(): + model_list = FedMLAttacker.get_instance().attack_model(raw_client_grad_list=model_list, extra_auxiliary_info=None) + + if FedMLDefender.get_instance().is_defense_enabled(): + # todo: update extra_auxiliary_info according to defense type + averaged_params = FedMLDefender.get_instance().defend( + raw_client_grad_list=model_list, + base_aggregation_func=self._fedavg_aggregation_, + extra_auxiliary_info=self.get_global_model_params(), + ) + else: + averaged_params = self._fedavg_aggregation_(model_list) + + # update the global model which is cached in the cloud + self.set_global_model_params(averaged_params) + + end_time = time.time() + logging.info("aggregate time cost: %d" % (end_time - start_time)) + return averaged_params + + def mix(self, topology_manager): + start_time = time.time() + + # Edge server may conduct partial aggregation multiple times, so cloud server will receive a model list + group_comm_round = len(self.sample_num_dict[0]) + edge_model_list = [None for _ in range(self.worker_num)] + + p = cal_mixing_consensus_speed(topology_manager.topology, self.model_dict[0][0][0], self.args) + + for group_round_idx in range(group_comm_round): + model_list = [] + global_round_idx = self.model_dict[0][group_round_idx][0] + + for idx in range(self.worker_num): + model_list.append((self.sample_num_dict[idx][group_round_idx], + self.model_dict[idx][group_round_idx][1])) + + # mixing between neighbors + for idx in range(self.worker_num): + edge_model_list[idx] = (self.sample_num_dict[idx][group_round_idx], + self._pfedavg_mixing_(model_list, + 
topology_manager.get_in_neighbor_weights(idx)) + ) + # average for testing + averaged_params = self._pfedavg_aggregation_(edge_model_list) + self.set_global_model_params(averaged_params) + self.test_on_cloud_for_all_clients(global_round_idx) + + # update the global model which is cached in the cloud + self.set_global_model_params(averaged_params) + + end_time = time.time() + logging.info("mix time cost: %d" % (end_time - start_time)) + return [edge_model for _, edge_model in edge_model_list] + + def _fedavg_aggregation_(self, model_list): + training_num = 0 + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + training_num += local_sample_number + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * local_sample_number / training_num + ) + else: + averaged_params[k] += ( + local_model_params[k] * local_sample_number / training_num + ) + return averaged_params + + def _pfedavg_aggregation_(self, model_list): + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + _, local_model_params = model_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * 1 / len(model_list) + ) + else: + averaged_params[k] += ( + local_model_params[k] * 1 / len(model_list) + ) + return averaged_params + + def _pfedavg_mixing_(self, model_list, neighbor_topo_weight_list): + training_num = 0 + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + training_num += local_sample_number + + (num0, averaged_params) = model_list[0] + averaged_params = copy.deepcopy(averaged_params) + for k in averaged_params.keys(): + for i in range(0, len(model_list)): + local_sample_number, local_model_params = model_list[i] + topo_weight = neighbor_topo_weight_list[i] + if i == 0: + averaged_params[k] = ( + local_model_params[k] * topo_weight + ) + else: + averaged_params[k] += ( + local_model_params[k] * topo_weight + ) + + return averaged_params + + def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + if client_num_in_total == client_num_per_round: + client_indexes = [ + client_index for client_index in range(client_num_in_total) + ] + else: + num_clients = min(client_num_per_round, client_num_in_total) + np.random.seed( + round_idx + ) # make sure for each comparison, we are selecting the same clients each round + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False + ) + logging.info("client_indexes = %s" % str(client_indexes)) + return client_indexes + + def _generate_validation_set(self, num_samples=10000): + if self.args.dataset.startswith("stackoverflow"): + test_data_num = len(self.test_global.dataset) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num) + ) + subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size + ) + return sample_testset + else: + return self.test_global + + def test_on_cloud_for_all_clients(self, global_round_idx): + if self.aggregator.test_all( + self.train_data_local_dict, + self.test_data_local_dict, + self.device, + self.args, + ): + return + + if ( + global_round_idx % self.args.frequency_of_the_test == 0 + or global_round_idx == self.args.comm_round 
* self.args.group_comm_round - 1 + ): + + logging.info("################test_on_cloud_for_all_clients : {}".format(global_round_idx)) + + # We may want to test the intermediate results of partial aggregated models, so we play a trick and let + # args.round_idx be total number of partial aggregated times + + round_idx = self.args.round_idx + self.args.round_idx = global_round_idx + + if global_round_idx == self.args.comm_round - 1: + # we allow to return four metrics, such as accuracy, AUC, loss, etc. + metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + else: + metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) + + self.args.round_idx = round_idx + + logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py new file mode 100644 index 0000000000..6111f44c4e --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgCloudManager.py @@ -0,0 +1,170 @@ +import logging + +from .message_define import MyMessage +from ....core.distributed.fedml_comm_manager import FedMLCommManager +from ....core.distributed.communication.message import Message +from .utils import post_complete_message_to_sweep_process + +class HierFedAVGCloudManager(FedMLCommManager): + def __init__( + self, + args, + aggregator, + group_indexes, + group_to_client_indexes, + comm=None, + rank=0, + size=0, + backend="MPI", + topology_manager=None + # is_preprocessed=False, + # preprocessed_client_lists=None, + ): + super().__init__(args, comm, rank, size, backend) + self.args = args + self.aggregator = aggregator + self.group_indexes = group_indexes + self.group_to_client_indexes = group_to_client_indexes + self.round_num = args.comm_round + self.args.round_idx = 0 + self.topology_manager = topology_manager + + total_clients = len(self.group_indexes) + self.group_to_client_num_per_round = [ + args.client_num_per_round * len(self.group_to_client_indexes[i]) // total_clients + for i in range(args.group_num) + ] + + remain_client_num_list_per_round = args.client_num_per_round - sum(self.group_to_client_num_per_round) + while remain_client_num_list_per_round > 0: + self.group_to_client_num_per_round[remain_client_num_list_per_round-1] += 1 + remain_client_num_list_per_round -= 1 + + # self.is_preprocessed = is_preprocessed + # self.preprocessed_client_lists = preprocessed_client_lists + + def run(self): + super().run() + + def send_init_msg(self): + # broadcast to edge servers + global_model_params = self.aggregator.get_global_model_params() + + sampled_group_to_client_indexes = {} + total_sampled_data_size = 0 + for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): + client_num_in_total = len(self.group_to_client_indexes[group_idx]) + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + client_num_in_total, + client_num_per_round, + ) + sampled_group_to_client_indexes[group_idx] = [] + for index in sampled_client_indexes: + client_idx = self.group_to_client_indexes[group_idx][index] + sampled_group_to_client_indexes[group_idx].append(client_idx) + total_sampled_data_size += self.aggregator.train_data_local_num_dict[client_idx] + + logging.info( + "client_indexes of each group = {}".format(sampled_group_to_client_indexes) + ) + + for process_id in range(1, 
self.size): + total_sampled_data_size = 0 if self.topology_manager is None else total_sampled_data_size + self.send_message_init_config( + process_id, + global_model_params, + self.group_to_client_indexes[process_id - 1], + sampled_group_to_client_indexes[process_id - 1], + total_sampled_data_size, + process_id - 1 + ) + + def register_message_receive_handlers(self): + self.register_message_receive_handler( + MyMessage.MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD, + self.handle_message_receive_model_from_edge, + ) + + def handle_message_receive_model_from_edge(self, msg_params): + sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) + model_params_list = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_LIST) + sample_num_list = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) + + self.aggregator.add_local_trained_result( + sender_id - 1, model_params_list, sample_num_list + ) + b_all_received = self.aggregator.check_whether_all_receive() + logging.info("b_all_received = " + str(b_all_received)) + if b_all_received: + # If topology_manage is None, it is simple average. Otherwise, it is mixing between neighbours. + if self.topology_manager is None: + global_model_params = self.aggregator.aggregate() + else: + global_model_params_list = self.aggregator.mix(self.topology_manager) + + # start the next round + self.args.round_idx += 1 + if self.args.round_idx == self.round_num: + post_complete_message_to_sweep_process(self.args) + self.finish() + return + + sampled_group_to_client_indexes = {} + total_sampled_data_size = 0 + for group_idx, client_num_per_round in enumerate(self.group_to_client_num_per_round): + client_num_in_total = len(self.group_to_client_indexes[group_idx]) + sampled_client_indexes = self.aggregator.client_sampling( + self.args.round_idx, + client_num_in_total, + client_num_per_round, + ) + sampled_group_to_client_indexes[group_idx] = [] + for index in sampled_client_indexes: + client_idx = self.group_to_client_indexes[group_idx][index] + sampled_group_to_client_indexes[group_idx].append(client_idx) + total_sampled_data_size += self.aggregator.train_data_local_num_dict[client_idx] + + logging.info( + "client_indexes of each group = {}".format(sampled_group_to_client_indexes) + ) + + for receiver_id in range(1, self.size): + if self.topology_manager is not None: + global_model_params = global_model_params_list[receiver_id - 1] + else: + total_sampled_data_size = 0 + self.send_message_sync_model_to_edge( + receiver_id, + global_model_params, + sampled_group_to_client_indexes[receiver_id - 1], + total_sampled_data_size, + receiver_id - 1 + ) + + def send_message_init_config(self, receive_id, global_model_params, total_client_indexes, + sampled_client_indexed, total_sampled_data_size, edge_index): + message = Message( + MyMessage.MSG_TYPE_C2E_INIT_CONFIG, self.get_sender_id(), receive_id + ) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS, total_client_indexes) + message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE, total_sampled_data_size) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) + self.send_message(message) + + def send_message_sync_model_to_edge( + self, receive_id, global_model_params, sampled_client_indexed, total_sampled_data_size, edge_index + ): + logging.info("send_message_sync_model_to_edge. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE, + self.get_sender_id(), + receive_id, + ) + message.add_params(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS, sampled_client_indexed) + message.add_params(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE, total_sampled_data_size) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params(MyMessage.MSG_ARG_KEY_EDGE_INDEX, str(edge_index)) + self.send_message(message) diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py new file mode 100644 index 0000000000..62a4efb106 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierFedAvgEdgeManager.py @@ -0,0 +1,74 @@ +import logging + +from .message_define import MyMessage +from ....core.distributed.fedml_comm_manager import FedMLCommManager +from ....core.distributed.communication.message import Message +from .utils import post_complete_message_to_sweep_process + + +class HierFedAVGEdgeManager(FedMLCommManager): + def __init__( + self, + group, + args, + comm=None, + rank=0, + size=0, + backend="MPI", + ): + super().__init__(args, comm, rank, size, backend) + self.num_rounds = args.comm_round + self.args.round_idx = 0 + self.group =group + + def run(self): + super().run() + + def register_message_receive_handlers(self): + self.register_message_receive_handler( + MyMessage.MSG_TYPE_C2E_INIT_CONFIG, self.handle_message_init + ) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE, + self.handle_message_receive_model_from_cloud, + ) + + def handle_message_init(self, msg_params): + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + total_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_EDGE_CLIENTS) + sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + total_sampled_data_size = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE) + edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) + + self.group.setup_clients(total_client_indexes) + self.args.round_idx = 0 + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, + sampled_client_indexes, total_sampled_data_size) + + self.send_model_to_cloud(0, w_group_list, sample_num_list) + + def handle_message_receive_model_from_cloud(self, msg_params): + logging.info("handle_message_receive_model_from_cloud.") + sampled_client_indexes = msg_params.get(MyMessage.MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS) + total_sampled_data_size = msg_params.get(MyMessage.MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE) + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + edge_index = msg_params.get(MyMessage.MSG_ARG_KEY_EDGE_INDEX) + + self.args.round_idx += 1 + w_group_list, sample_num_list = self.group.train(self.args.round_idx, global_model_params, + sampled_client_indexes, total_sampled_data_size) + self.send_model_to_cloud(0, w_group_list, sample_num_list) + + if self.args.round_idx == self.num_rounds: + post_complete_message_to_sweep_process(self.args) + self.finish() + + def send_model_to_cloud(self, receive_id, w_group_list, edge_sample_num): + message = Message( + MyMessage.MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD, + self.get_sender_id(), + receive_id, + ) + message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_LIST, w_group_list) + message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, edge_sample_num) + self.send_message(message) 
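Taken together, the two managers implement a three-message protocol per cloud round: the cloud sends `C2E_INIT_CONFIG` (or `C2E_SYNC_MODEL_TO_EDGE`), each edge runs `group_comm_round` partial aggregations, and replies with `E2C_SEND_MODEL_TO_CLOUD` carrying a *list* of models. A sketch of the edge-to-cloud payload, with illustrative values (the `Message` class and `add_params` API are the ones used above; constants mirror `message_define.py`, next file):

```python
# Sketch of the edge->cloud payload format; literal values are illustrative.
from fedml.core.distributed.communication.message import Message

MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD = 3  # mirrors MyMessage below

msg = Message(MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD, 1, 0)  # edge rank 1 -> cloud rank 0
# one (global_round_idx, state_dict) entry per partial aggregation this round
msg.add_params("model_params_list", [(0, {"weight": None}), (1, {"weight": None})])
# matching per-round group sample counts, used for weighted averaging
msg.add_params("num_samples", [1200, 1200])
```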
diff --git a/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py new file mode 100644 index 0000000000..ef71d5e320 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/HierGroup.py @@ -0,0 +1,81 @@ +import logging + +from .HierClient import HFLClient +from ...sp.fedavg.fedavg_api import FedAvgAPI + + +class HierGroup(FedAvgAPI): + def __init__( + self, + idx, + train_data_local_dict, + test_data_local_dict, + train_data_local_num_dict, + args, + device, + model, + model_trainer, + ): + self.idx = idx + self.args = args + self.device = device + self.client_dict = {} + self.train_data_local_num_dict = train_data_local_num_dict + self.train_data_local_dict = train_data_local_dict + self.test_data_local_dict = test_data_local_dict + self.model = model + self.model_trainer = model_trainer + self.args = args + + def setup_clients(self, total_client_indexes): + self.client_dict = {} + for client_idx in total_client_indexes: + self.client_dict[client_idx] = HFLClient( + client_idx, + self.train_data_local_dict[client_idx], + self.test_data_local_dict[client_idx], + self.train_data_local_num_dict[client_idx], + self.args, + self.device, + self.model, + self.model_trainer, + ) + + def get_sample_number(self, sampled_client_indexes): + self.group_sample_number = 0 + for client_idx in sampled_client_indexes: + self.group_sample_number += self.train_data_local_num_dict[client_idx] + return self.group_sample_number + + def train(self, round_idx, w, sampled_client_indexes, total_sampled_data_size=0): + sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] + w_group = w + w_group_list = [] + sample_num_list = [] + for group_round_idx in range(self.args.group_comm_round): + logging.info("Group ID : {} / Group Communication Round : {}".format(self.idx, group_round_idx)) + w_locals = [] + + global_round_idx = ( + round_idx * self.args.group_comm_round + + group_round_idx + ) + # train each client + for client in sampled_client_list: + if total_sampled_data_size > 0: + scaled_loss_factor = ( + self.args.group_num * len(sampled_client_list) + * client.local_sample_number / total_sampled_data_size + ) + w_local = client.train(w_group, scaled_loss_factor) + else: + w_local = client.train(w_group) + w_locals.append((client.get_sample_number(), w_local)) + + # aggregate local weights + w_group_list.append((global_round_idx, self._aggregate(w_locals))) + sample_num_list.append(self.get_sample_number(sampled_client_indexes)) + + # update the group weight + w_group = w_group_list[-1][1] + return w_group_list, sample_num_list diff --git a/python/fedml/simulation/mpi/hierarchical_fl/__init__.py b/python/fedml/simulation/mpi/hierarchical_fl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fedml/simulation/mpi/hierarchical_fl/message_define.py b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py new file mode 100644 index 0000000000..f5d6cd1c73 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/message_define.py @@ -0,0 +1,35 @@ +class MyMessage(object): + """ + message type definition + """ + + # cloud to edge + MSG_TYPE_C2E_INIT_CONFIG = 1 + MSG_TYPE_C2E_SYNC_MODEL_TO_EDGE = 2 + + # edge to cloud + MSG_TYPE_E2C_SEND_MODEL_TO_CLOUD = 3 + MSG_TYPE_E2C_SEND_STATS_TO_CLOUD = 4 + + MSG_ARG_KEY_TYPE = "msg_type" + MSG_ARG_KEY_SENDER = "sender" + MSG_ARG_KEY_RECEIVER = "receiver" + + """ + message payload keywords definition + """ + MSG_ARG_KEY_NUM_SAMPLES 
= "num_samples" + MSG_ARG_KEY_MODEL_PARAMS = "model_params" + MSG_ARG_KEY_MODEL_PARAMS_LIST = "model_params_list" + MSG_ARG_KEY_EDGE_INDEX = "edge_idx" + MSG_ARG_KEY_TOTAL_EDGE_CLIENTS = "total_edge_clients" + MSG_ARG_KEY_SAMPLED_EDGE_CLIENTS = "sampled_edge_clients" + MSG_ARG_KEY_TOTAL_SAMPLED_DATA_SIZE = "total_sampled_data_size" + + MSG_ARG_KEY_TRAIN_CORRECT = "train_correct" + MSG_ARG_KEY_TRAIN_ERROR = "train_error" + MSG_ARG_KEY_TRAIN_NUM = "train_num_sample" + + MSG_ARG_KEY_TEST_CORRECT = "test_correct" + MSG_ARG_KEY_TEST_ERROR = "test_error" + MSG_ARG_KEY_TEST_NUM = "test_num_sample" diff --git a/python/fedml/simulation/mpi/hierarchical_fl/utils.py b/python/fedml/simulation/mpi/hierarchical_fl/utils.py new file mode 100644 index 0000000000..4f7987ec35 --- /dev/null +++ b/python/fedml/simulation/mpi/hierarchical_fl/utils.py @@ -0,0 +1,139 @@ +import os +import time + +import numpy as np +import torch +import wandb +import logging + +from sklearn.cluster import KMeans + + +def cal_mixing_consensus_speed(topo_weight_matrix, global_round_idx, args): + n_rows, n_cols = np.shape(topo_weight_matrix) + assert n_rows == n_cols + A = np.array(topo_weight_matrix) - 1 / n_rows + p = 1 - np.linalg.norm(A, ord=2) ** 2 + if args.enable_wandb: + wandb.log({"Groups/p": p, "comm_round": global_round_idx}) + return p + + +def stats_group(group_to_client_indexes, train_data_local_dict, train_data_local_num_dict, class_num, args): + + xs = [i for i in range(class_num)] + ys = [] + keys = [] + for group_idx in range(len(group_to_client_indexes)): + data_size = 0 + group_y_train = [] + for client_id in group_to_client_indexes[group_idx]: + data_size += train_data_local_num_dict[client_id] + y_train = torch.concat([y for _, y in train_data_local_dict[client_id]]).tolist() + group_y_train.extend(y_train) + + labels, counts = np.unique(group_y_train, return_counts=True) + + count_vector = np.zeros(class_num) + count_vector[labels] = counts + ys.append(count_vector/count_vector.sum()) + keys.append("Group {}".format(group_idx)) + + if args.enable_wandb: + wandb.log({"Groups/Client_num": len(group_to_client_indexes[group_idx]), "group_id": group_idx}) + wandb.log({"Groups/Data_size": data_size, "group_id": group_idx}) + + logging.info("Group {}: client num={}, data size={} ".format( + group_idx, + len(group_to_client_indexes[group_idx]), + data_size + )) + + if args.enable_wandb: + wandb.log({"Groups/Data_distribution": + wandb.plot.line_series(xs=xs, ys=ys, keys=keys, title="Data distribution", xname="Label")} + ) + + +def hetero_partition_groups(clients_type_list, num_groups, alpha=0.5): + min_size = 0 + num_type = np.unique(clients_type_list).size + N = len(clients_type_list) + group_to_client_indexes = {} + while min_size < 10: + idx_batch = [[] for _ in range(num_groups)] + # for each type in clients + for k in range(num_type): + idx_k = np.where(np.array(clients_type_list) == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, num_groups)) + ## Balance + proportions = np.array([p * (len(idx_j) < N / num_groups) for p, idx_j in zip(proportions, idx_batch)]) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + group_indexes = [0 for _ in range(N)] + for j in range(num_groups): + np.random.shuffle(idx_batch[j]) + group_to_client_indexes[j] = idx_batch[j] 
+        for client_id in group_to_client_indexes[j]:
+            group_indexes[client_id] = j
+
+    return group_indexes, group_to_client_indexes
+
+
+def analyze_clients_type(train_data_local_dict, class_num, num_type=5):
+    client_feature_list = []
+    for i in range(len(train_data_local_dict)):
+        y_train = torch.concat([y for _, y in train_data_local_dict[i]])
+        labels, counts = torch.unique(y_train, return_counts=True)
+        data_feature = np.zeros(class_num)
+        total = 0
+        for label, count in zip(labels, counts):
+            data_feature[label.item()] = count.item()
+            total += count.item()
+        data_feature /= total
+        client_feature_list.append(data_feature)
+
+    # n_init="auto" requires scikit-learn >= 1.2; pass an int (e.g. 10) on older versions
+    kmeans = KMeans(n_clusters=num_type, random_state=0, n_init="auto").fit(client_feature_list)
+
+    return kmeans.labels_
+
+
+def transform_list_to_tensor(model_params_list):
+    for k in model_params_list.keys():
+        model_params_list[k] = torch.from_numpy(
+            np.asarray(model_params_list[k])
+        ).float()
+    return model_params_list
+
+
+def transform_tensor_to_list(model_params):
+    for k in model_params.keys():
+        model_params[k] = model_params[k].detach().numpy().tolist()
+    return model_params
+
+
+def post_complete_message_to_sweep_process(args):
+    pipe_path = "./tmp/fedml"
+    # only create the directory here; touching the file as well would leave a
+    # regular file behind and keep os.mkfifo from ever creating the pipe
+    os.makedirs("./tmp/", exist_ok=True)
+    if not os.path.exists(pipe_path):
+        os.mkfifo(pipe_path)
+    pipe_fd = os.open(pipe_path, os.O_WRONLY)
+
+    with os.fdopen(pipe_fd, "w") as pipe:
+        pipe.write("training is finished! \n%s\n" % (str(args)))
+        time.sleep(3)
diff --git a/python/fedml/simulation/simulator.py b/python/fedml/simulation/simulator.py
index abf0394869..558d0f6aa9 100644
--- a/python/fedml/simulation/simulator.py
+++ b/python/fedml/simulation/simulator.py
@@ -90,6 +90,7 @@ def __init__(
         from .mpi.fedavg_seq.FedAvgSeqAPI import FedML_FedAvgSeq_distributed
         from .mpi.async_fedavg.AsyncFedAvgSeqAPI import FedML_Async_distributed
         from .mpi.fednova.FedNovaAPI import FedML_FedNova_distributed
+        from .mpi.hierarchical_fl.HierFedAvgAPI import FedML_HierFedAvg_distributed
 
         if args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_FEDAVG:
             FedML_FedAvg_distributed(
@@ -159,6 +160,18 @@ def __init__(
             SplitNN_distributed(
                 args.process_id, args.worker_num, device, args.comm, model, dataset=dataset, args=args,
             )
+        elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_HIERACHICAL_FL:
+            FedML_HierFedAvg_distributed(
+                args,
+                args.process_id,
+                args.worker_num,
+                args.comm,
+                device,
+                dataset,
+                model,
+                client_trainer=client_trainer,
+                server_aggregator=server_aggregator
+            )
         elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_DECENTRALIZED_FL:
             FedML_Decentralized_Demo_distributed(args, args.process_id, args.worker_num, args.comm)
         elif args.federated_optimizer == FedML_FEDERATED_OPTIMIZER_FEDGAN:
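A quick end-to-end sanity check for a topology choice: build an overlay from `topo_utils.py` and compute the same spectral quantity that `cal_mixing_consensus_speed` logs to wandb (p closer to 1 means faster consensus). Sketch:

```python
# Compare overlays by the mixing speed p = 1 - ||W - J/n||_2^2 used in
# cal_mixing_consensus_speed above (J/n is the uniform averaging matrix).
import numpy as np
from fedml.core.distributed.topology.topo_utils import (
    get_complete_overlay,
    get_star_overlay,
)

n = 5
for name, W in [("complete", get_complete_overlay(n)), ("star", get_star_overlay(n))]:
    p = 1 - np.linalg.norm(np.array(W) - 1 / n, ord=2) ** 2
    print("%s: p = %.3f" % (name, p))  # complete mixes in one step (p = 1.0)
```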