diff --git a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py index af2870b2c3..1172f4107f 100644 --- a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py +++ b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py @@ -35,12 +35,18 @@ random_seed = 5495300300540669060 -g_device = torch.Generator(device="cuda") -# Uncomment the line below to use g_cpu if not using cuda -# g_device = torch.Generator() # noqa: E800 -# NOTE: remove below to stop repeatable runs +# Fixing the seed for result repeatation: remove below to stop repeatable runs +# ---------------------------------- +# Fixing the seed for result reproducibility +random_seed = 5495300300540669060 +# Determine the device to use (CUDA if available, otherwise CPU) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# Create a generator for the specified device +g_device = torch.Generator(device=device) +# Set the seed for the generator g_device.manual_seed(random_seed) -print(f"\n\nWe are using seed: {random_seed}") +print(f"\n\nWe are using seed: {random_seed} on device: {device}") +# ---------------------------------- mnist_train = torchvision.datasets.MNIST( "files/", @@ -134,6 +140,33 @@ def populate_model_params_and_gradients( # use only the first state params.grad = states_for_gradients[0][name] + def __getstate__(self): + state = self.__dict__.copy() + # Remove the optimizer and privacy engine from the state + state['global_optimizer'] = None + state['privacy_engine'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # Recreate the optimizer and privacy engine + self.global_optimizer = torch.optim.SGD( + params=self.global_model.parameters(), lr=1.0 + ) + self.privacy_engine = PrivacyEngine() + if self.dp_params is not None: + ( + self.global_model, + self.global_optimizer, + self.global_data_loader, + ) = self.privacy_engine.make_private( + module=self.global_model, + optimizer=self.global_optimizer, + data_loader=self.global_data_loader, + noise_multiplier=self.dp_params["noise_multiplier"], + max_grad_norm=self.dp_params["clip_norm"], + ) + class Net(nn.Module): def __init__(self): @@ -202,7 +235,7 @@ def FedAvg(models, global_model_tools, previous_global_state, dp_params): # NOQ # global_model_tools.global_model state dict new_model.load_state_dict( { - key[8:]: value + key.replace('_module.', ''): value for key, value in global_model_tools.global_model.state_dict().items() } ) @@ -223,7 +256,7 @@ def inference(network, test_loader, device): correct += pred.eq(target.data.view_as(pred)).sum() test_loss /= len(test_loader.dataset) print( - "\nTest set: Avg. loss: {test_loss:.4f}," + f"\nTest set: Avg. loss: {test_loss:.4f}," f" Accuracy: {correct}/{len(test_loader.dataset)}" f" ({(100.0 * correct / len(test_loader.dataset)):.0f})\n" ) @@ -338,12 +371,6 @@ def __init__( self.dp_params = config["differential_privacy"] print(f"Here are dp_params: {self.dp_params}") validate_dp_params(self.dp_params) - self.global_model_tools = GlobalModelTools( - global_model=self.global_model, - example_model_state=self.model.state_dict(), - collaborator_names=self.collaborator_names, - dp_params=self.dp_params, - ) @aggregator def start(self): @@ -596,9 +623,10 @@ def end(self): else: device = torch.device("cpu") + # Initialize the model + initial_model = Net() + # Setup participants - # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU - aggregator = Aggregator(num_gpus=0.09) # Setup collaborators with private attributes collaborator_names = [ @@ -614,6 +642,28 @@ def end(self): "Guadalajara", ] + # Initialize the aggregator with the callable + # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU + + def callable_to_initialize_aggregator_private_attributes(): + # Load the configuration file to get dp_params + with open(args.config_path, "rb") as _file: + config = yaml.safe_load(_file) + dp_params = config["differential_privacy"] + + return {'global_model_tools': GlobalModelTools( + global_model=initial_model, + example_model_state=initial_model.state_dict(), + collaborator_names=collaborator_names, + dp_params=dp_params + )} + + aggregator = Aggregator( + name="agg", + private_attributes_callable= callable_to_initialize_aggregator_private_attributes, + num_gpus=0.0 + ) + def callable_to_initialize_collaborator_private_attributes( index, n_collaborators, batch_size, train_dataset, test_dataset ): @@ -641,7 +691,7 @@ def callable_to_initialize_collaborator_private_attributes( private_attributes_callable=callable_to_initialize_collaborator_private_attributes, # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU num_cpus=0.0, - num_gpus=0.09, # Assuming GPU(s) is available in the machine + num_gpus=0.0, # Assuming GPU(s) is available in the machine index=idx, n_collaborators=len(collaborator_names), batch_size=batch_size_train, @@ -656,7 +706,6 @@ def callable_to_initialize_collaborator_private_attributes( print(f"Local runtime collaborators = {local_runtime.collaborators}") best_model = None - initial_model = Net() top_model_accuracy = 0 total_rounds = 10 @@ -670,4 +719,4 @@ def callable_to_initialize_collaborator_private_attributes( clip_test=args.clip_test, ) flflow.runtime = local_runtime - flflow.run() + flflow.run() \ No newline at end of file diff --git a/openfl-tutorials/experimental/workflow/Global_DP/requirements_global_dp.txt b/openfl-tutorials/experimental/workflow/Global_DP/requirements_global_dp.txt index f5118684aa..d145b7873d 100644 --- a/openfl-tutorials/experimental/workflow/Global_DP/requirements_global_dp.txt +++ b/openfl-tutorials/experimental/workflow/Global_DP/requirements_global_dp.txt @@ -1,9 +1,9 @@ cloudpickle -matplotlib==3.6.0 -numpy==1.23.3 -opacus==1.5.1 -pillow==10.3.0 -pyyaml==6.0 -torch==2.2.0 -torchaudio==2.2.0 -torchvision==0.17.0 +matplotlib==3.10.0 +numpy==1.26.4 +opacus==1.5.2 +pillow==11.0.0 +pyyaml==6.0.2 +torch==2.5.1 +torchaudio==2.5.1 +torchvision==0.20.1