This project demonstrates a federated learning setup using the Flower framework and PyTorch with the CelebA dataset. The experiment involves both IID and non-IID data splits and uses a pre-trained MobileNetV2 model for efficient training.
- Data Partitioning: Generate IID and non-IID distributions for federated training.
- Pre-trained Model: Leverage MobileNetV2 with a frozen feature extractor for computational efficiency.
- Federated Learning Simulation: Simulate federated training with 50 clients over 10 communication rounds across demographic groups.
- Evaluation: Analyze model performance using a classification report, a confusion matrix, and learning curves.
- Real-World Federated Learning Execution: Implement federated learning in a real-world setting using gRPC for seamless client-server communication, replicating practical deployment scenarios.
This project requires Python 3.9 or higher. The complete list of dependencies is in requirements.txt.
- Install Python 3.9 or higher.
- Install the required dependencies:
pip install -r requirements.txt
federated-learning-celeba/
├── configs/
│   ├── client_id.py
│   └── config.py
├── data/
│   ├── data_loader.py
│   ├── fed_dataset.py
│   └── partitioner/
│       ├── visual_dirichlet_partitioner.py
│       └── iid_partitioner.py
├── model/
│   └── mobilenetv2.py
├── results/
│   └── Federated_Learning_with_Flower_and_CelebA.ipynb
├── strategies/
│   └── FedAvg.py
├── client.py
├── main.py
├── README.md
├── requirements.txt
├── server.py
└── task.py
- configs/: Configuration files to manage both global and client-specific settings.
  - client_id.py: Defines unique identifiers and configurations for each client.
  - config.py: Centralized configuration file for global settings.
- data/: Scripts and utilities for data loading, preprocessing, and partitioning.
  - data_loader.py: Handles the loading and preprocessing of the CelebA dataset.
  - fed_dataset.py: Defines the FedDataset class for creating federated datasets.
  - partitioner/: Implements strategies for data partitioning.
    - visual_dirichlet_partitioner.py: Creates non-IID distributions using a Dirichlet distribution.
    - iid_partitioner.py: Creates IID partitions for federated learning experiments.
- model/: Contains the definition of the training model.
  - mobilenetv2.py: Implements a MobileNetV2 model with a frozen feature extractor and a customizable classifier head.
- results/: Stores results and logs from experiments.
  - Federated_Learning_with_Flower_and_CelebA.ipynb: Contains detailed performance assessments and visualizations.
- strategies/: Includes custom federated learning strategies.
  - FedAvg.py: Implementation of the Federated Averaging algorithm.
- client.py: Defines the behavior of federated clients, including local training and communication with the server.
- server.py: Manages the federated learning server, including aggregation, coordination, and evaluation.
- task.py: Contains task-specific logic for model training, testing, and evaluation.
- main.py: Central script to orchestrate and execute the federated learning experiment.
- requirements.txt: Lists all dependencies required to run the project.
- README.md: Provides documentation, including the project overview, usage instructions, and component descriptions.
The first step in the project is to load the CelebA dataset, add demographic labels, and partition the dataset into client-specific subsets for federated learning. This step allows flexible data distribution strategies, such as IID and non-IID, to simulate diverse real-world scenarios.
- Load the CelebA Dataset: The CelebA dataset is loaded using the datasets library, making it easily accessible for partitioning and labeling.
- Add Demographic Labels: A custom function, add_demographic_labels, categorizes each data point based on two attributes, Male and Young. The demographic labels are:
  - 0: Not Male & Not Young
  - 1: Not Male & Young
  - 2: Male & Not Young
  - 3: Male & Young
  These labels enable the analysis of model performance across demographic groups.
- Select Partitioning Strategy: The dataset is partitioned according to the specified distribution:
  - IID Partitioning: Evenly splits the dataset across all clients.
  - Non-IID Partitioning: Uses the Visual Dirichlet Partitioner to create skewed distributions based on demographic labels, with the alpha parameter controlling the degree of non-uniformity (smaller values produce more skewed partitions).
- Partition the Dataset: The dataset is divided into num_partitions partitions, and the specified partition_id determines which subset a client receives, using the FedDataset class.
- Split into Training and Validation: Each client’s data subset is split into 80% training and 20% validation data.
- Apply Transformations: PyTorch transformations, such as normalization, are applied to prepare the data for training.
- Create DataLoaders: The function returns three DataLoader objects:
  - trainloader: For client-specific training data.
  - valloader: For local validation.
  - testloader: For global testing across all clients.
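The labeling, partitioning, and splitting steps above can be sketched with NumPy. This is an illustration under stated assumptions rather than the repository's code: the helper names are hypothetical, the Male and Young attributes are assumed to arrive as booleans or 0/1 values, and dirichlet_partition is a generic label-skew scheme that may differ in detail from the Visual Dirichlet Partitioner.

```python
import numpy as np


def demographic_label(male, young) -> int:
    """Map the (Male, Young) attribute pair to a group id.

    0: Not Male & Not Young, 1: Not Male & Young,
    2: Male & Not Young,     3: Male & Young
    """
    return 2 * int(bool(male)) + int(bool(young))


def iid_partition(num_samples: int, num_partitions: int, seed: int = 0) -> list[np.ndarray]:
    """Shuffle all sample indices and split them into near-equal chunks."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(num_samples), num_partitions)


def dirichlet_partition(labels, num_partitions: int, alpha: float, seed: int = 0) -> list[np.ndarray]:
    """Label-skewed split (hypothetical sketch).

    For every class, client shares are drawn from Dir(alpha): small alpha
    gives highly skewed clients, large alpha approaches an IID split.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    buckets: list[list[int]] = [[] for _ in range(num_partitions)]
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.flatnonzero(labels == cls))
        proportions = rng.dirichlet(alpha * np.ones(num_partitions))
        # Turn cumulative proportions into split points over this class's indices.
        cuts = (np.cumsum(proportions)[:-1] * len(cls_idx)).astype(int)
        for pid, chunk in enumerate(np.split(cls_idx, cuts)):
            buckets[pid].extend(chunk.tolist())
    return [np.array(sorted(b), dtype=int) for b in buckets]


def train_val_split(indices, val_fraction: float = 0.2, seed: int = 0):
    """Split one client's indices into training and validation (80/20 by default)."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(np.asarray(indices))
    n_val = int(round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]
```

For example, dirichlet_partition(labels, num_partitions=50, alpha=0.5) assigns every sample to exactly one of 50 clients, while a much larger alpha produces partitions whose group mix stays close to the global one.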
The mobilenetv2.py script defines a MobileNetV2 model:
- Feature Extractor: Pre-trained and frozen to reduce computational overhead.
- Classifier Head: Customizable for the specific task of classifying CelebA attributes.
This step involves setting up the core components of the federated learning framework. It defines the interactions between clients and the server, implements the Federated Averaging strategy, and integrates task-specific logic for model training and evaluation.
The client script handles the following key responsibilities:
- Local Training: Each client trains its model on its partitioned dataset, leveraging the task-specific training and testing functions.
- Communication: The client receives the aggregated global model from the server, uses it as the starting point for local training, and sends its locally updated model weights back to the server.
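Exchanging the model with the server requires converting it to and from plain arrays, because Flower's NumPyClient sends and receives parameters as a list of NumPy arrays. The sketch below assumes client.py follows the common PyTorch pattern of serializing the state_dict in a fixed order; get_weights and set_weights are hypothetical helper names, not taken from the repository.

```python
from collections import OrderedDict

import numpy as np
import torch
from torch import nn


def get_weights(net: nn.Module) -> list[np.ndarray]:
    """Serialize every state_dict entry (parameters and buffers) to NumPy."""
    return [val.cpu().numpy() for val in net.state_dict().values()]


def set_weights(net: nn.Module, parameters: list[np.ndarray]) -> None:
    """Load arrays received from the server back into the model, in state_dict order."""
    keys = list(net.state_dict().keys())
    if len(keys) != len(parameters):
        raise ValueError(f"expected {len(keys)} arrays, got {len(parameters)}")
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)
```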
The server script manages the federated learning process by:
- Model Aggregation: Uses the Federated Averaging strategy to aggregate model updates from all clients.
- Coordination: Orchestrates communication rounds, ensuring proper synchronization between clients and the global model.
- Evaluation: Includes an evaluation function that lets the server assess the aggregated model’s performance on a global test set after each communication round.
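Server-side evaluation can be wired in through the evaluate_fn argument of Flower's FedAvg strategy, which calls a function with the arguments (server_round, parameters, config) and expects (loss, metrics) or None in return. The following sketch builds such a callback; make_evaluate_fn and the evaluate_weights hook (which would load the arrays into the model and score it on the global test set) are hypothetical, and the optional history list is one possible way to record per-round results for learning curves.

```python
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np


def make_evaluate_fn(
    evaluate_weights: Callable[[List[np.ndarray]], Tuple[float, float]],
    history: Optional[List[dict]] = None,
):
    """Return a server-side evaluation callback (hypothetical helper).

    evaluate_weights takes the global model's arrays and returns
    (loss, accuracy) on the global test set.
    """

    def evaluate(server_round: int, parameters: List[np.ndarray], config: Dict) -> Tuple[float, Dict[str, float]]:
        loss, accuracy = evaluate_weights(parameters)
        if history is not None:
            # Keep a per-round log that can later be plotted as learning curves.
            history.append({"round": server_round, "loss": loss, "accuracy": accuracy})
        return loss, {"accuracy": accuracy}

    return evaluate
```

The returned function would then be passed as FedAvg(..., evaluate_fn=make_evaluate_fn(...)).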
The Federated Averaging strategy script implements the core aggregation logic. Key features:
- Model Aggregation: Combines model weights from clients based on their dataset sizes.
- Modified Evaluation Function: Enhances the standard FedAvg algorithm by enabling the server to evaluate the model’s performance after every communication round, providing real-time feedback on convergence.
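Size-weighted aggregation computes each layer of the global model as w = Σ_k (n_k / n) w_k, where client k trained on n_k examples and n = Σ_k n_k. The NumPy sketch below illustrates that rule in isolation; fedavg_aggregate is a hypothetical name, and Flower's built-in FedAvg already performs this step internally, so strategies/FedAvg.py may not contain such a function.

```python
from typing import List, Tuple

import numpy as np


def fedavg_aggregate(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Average client weights layer by layer, weighted by local dataset size.

    results holds one (weights, num_examples) pair per client; every client's
    weights list must use the same layer order and shapes.
    """
    total_examples = sum(n for _, n in results)
    if total_examples == 0:
        raise ValueError("cannot aggregate: clients reported zero examples")
    num_layers = len(results[0][0])
    aggregated = []
    for layer in range(num_layers):
        weighted_sum = sum(weights[layer] * n for weights, n in results)
        aggregated.append(weighted_sum / total_examples)
    return aggregated
```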
This script contains essential functions for training, testing, and evaluating the model, including:
- Training (train_fn): Handles local model training on client devices.
- Testing (test_fn): Evaluates the model on the global test set and client-specific validation sets.
- Metrics Calculation: Computes and returns metrics such as accuracy, precision, recall, and F1-score. Results from each round are saved in the results/ directory for analysis.
The main.py script serves as the entry point to orchestrate the entire federated learning experiment. It ensures proper coordination between the server and clients while handling the configuration of the training process and evaluation.
- Setting Up the Server and Clients: The script initializes:
  - The server application (ServerApp) using the server_fn defined in server.py, which manages global model aggregation and evaluation.
  - The client application (ClientApp) using the client_fn defined in client.py, which handles local training and communication with the server.
- Configuring Clients and Backend: The number of participating clients (num_supernodes) and backend settings are loaded dynamically from the configuration file (configs/config.py).
- Initiating Training and Evaluation: The run_simulation function coordinates the communication rounds between the server and clients:
  - Starts the training process across multiple clients.
  - Aggregates local model updates at the server.
  - Evaluates the global model after each communication round.
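The settings that main.py reads could be laid out as module-level constants. The sketch below is hypothetical apart from NUM_PARTITIONS and backend_config, which this README names; the values mirror the 50-client, 10-round setup described earlier, ALPHA = 0.5 is an arbitrary example, and simulation_kwargs is an illustrative helper. The backend_config shape follows Flower's simulation engine, where client_resources sets the CPUs and GPUs reserved per simulated client.

```python
# configs/config.py: hypothetical sketch; the actual file may use other names.
NUM_PARTITIONS = 50        # number of simulated clients / data partitions
NUM_ROUNDS = 10            # communication rounds
DISTRIBUTION = "non-iid"   # assumed switch: "iid" or "non-iid"
ALPHA = 0.5                # example Dirichlet concentration for the non-IID split

# Resources Flower's simulation engine reserves for each simulated client.
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}


def simulation_kwargs() -> dict:
    """Keyword arguments main.py could forward to flwr.simulation.run_simulation."""
    return {"num_supernodes": NUM_PARTITIONS, "backend_config": backend_config}
```

main.py could then call run_simulation(server_app=server_app, client_app=client_app, **cfg.simulation_kwargs()), keeping the client count and resources in one place.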
The evaluation process assesses the performance of the global model after each communication round. Results are saved in the results/ directory for detailed analysis and visualization.
- Classification Report: Includes precision, recall, and F1-score for each class, providing insight into the model's performance across demographic labels.
- Confusion Matrix: Captures the distribution of correct and incorrect predictions, helping identify patterns and areas for improvement.
- Learning Curves: Tracks training and validation accuracy and loss over communication rounds, providing a clear view of the model’s convergence behavior.
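The classification report and confusion matrix can both be derived from the true and predicted labels of the global test set. The dependency-free sketch below shows the arithmetic, with rows of the matrix as true classes and columns as predicted classes; scikit-learn's confusion_matrix and classification_report compute the same quantities, and the helper names here are hypothetical.

```python
from typing import Dict, List, Sequence


def confusion_matrix(targets: Sequence[int], preds: Sequence[int], num_classes: int) -> List[List[int]]:
    """Count (true, predicted) pairs; rows = true class, columns = predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        matrix[t][p] += 1
    return matrix


def per_class_report(matrix: List[List[int]]) -> Dict[int, Dict[str, float]]:
    """Precision, recall, F1, and support for each class from the confusion matrix."""
    n = len(matrix)
    report = {}
    for c in range(n):
        tp = matrix[c][c]
        predicted = sum(matrix[r][c] for r in range(n))  # column sum
        actual = sum(matrix[c])                          # row sum = support
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = {"precision": precision, "recall": recall, "f1": f1, "support": actual}
    return report
```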
A Jupyter notebook, Federated_Learning_with_Flower_and_CelebA.ipynb, is included for in-depth evaluation and visualization. This notebook provides:
- Visual representations of learning curves.
- Insights into class-wise performance using classification reports and confusion matrices.
- A summary of key metrics over the federated learning process.
- Clone the repository:
  git clone https://github.com/TianyueChu/FedFlower.git
  cd FedFlower
- Install the required dependencies:
  pip install -r requirements.txt
- Configure the simulation:
  Edit the configs/config.py file to set the desired number of clients (NUM_PARTITIONS), backend configurations (backend_config), and other parameters.
- Run the federated learning experiment:
  python main.py
To run federated learning in a real-world setting, using gRPC for client-server communication:
- In server.py, uncomment the line:
  server = fl.server.start_server(server_address="0.0.0.0:8080", config=ServerConfig(num_rounds=10), strategy=strategy)
- In client.py, change the lines:
  partition_id = context.node_config["partition-id"]
  num_partitions = context.node_config["num-partitions"]
  to:
  partition_id = get_or_create_partition_id(client_id)
  num_partitions = cfg.NUM_PARTITIONS
  Then uncomment the lines:
  initialize_partition_file(cfg.NUM_PARTITIONS)
  flwr.client.start_client(server_address="127.0.0.1:8080", client_fn=client_fn)
- Start the server and client applications in separate terminals:
  python server.py
  python client.py
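client.py calls get_or_create_partition_id and initialize_partition_file, but their implementation is not described here. One plausible approach, sketched below purely as an assumption, keeps a small JSON file that maps each client id to a partition id, so a restarted client reuses its partition. The file name is a placeholder, and the sketch does no file locking, so clients should be started one at a time.

```python
import json
from pathlib import Path

# Hypothetical sketch: the repository's real helpers may work differently.
PARTITION_FILE = Path("partition_assignments.json")


def initialize_partition_file(num_partitions: int, path: Path = PARTITION_FILE) -> None:
    """Create the assignment file if missing; an existing file is left untouched."""
    if path.exists():
        return
    state = {"free": list(range(num_partitions)), "assigned": {}}
    path.write_text(json.dumps(state))


def get_or_create_partition_id(client_id, path: Path = PARTITION_FILE) -> int:
    """Return the partition already bound to client_id, or claim the next free one."""
    key = str(client_id)  # JSON object keys are always strings
    state = json.loads(path.read_text())
    if key in state["assigned"]:
        return state["assigned"][key]
    if not state["free"]:
        raise RuntimeError("all partitions are already assigned")
    partition_id = state["free"].pop(0)
    state["assigned"][key] = partition_id
    path.write_text(json.dumps(state))  # no locking: concurrent starts may race
    return partition_id
```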
- Explore the saved evaluation metrics in the results/server/ directory.
- Use the Jupyter notebook Federated_Learning_with_Flower_and_CelebA.ipynb for detailed analysis and visualization.
This project is licensed under the MIT License - see the LICENSE file for details.
- Flower: https://flower.dev/
- PyTorch: https://pytorch.org/
- CelebA Dataset from MMLab