From 3ecc5610383cf8c3195d633e6d35cbd3e6363821 Mon Sep 17 00:00:00 2001 From: Agniva Chowdhury Date: Mon, 23 Dec 2024 09:47:31 -0800 Subject: [PATCH] Fix trailing whitespace and end-of-file newlines Signed-off-by: Agniva Chowdhury --- ...rkflow_Interface_Mnist_Implementation_1.py | 21 ++++++++++++------- ...rkflow_Interface_Mnist_Implementation_2.py | 6 +++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_1.py b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_1.py index 1a2035e549..c80af36571 100644 --- a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_1.py +++ b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_1.py @@ -38,13 +38,15 @@ # Fixing the seed for result repeatation: remove below to stop repeatable runs # ---------------------------------- +# Fixing the seed for result reproducibility random_seed = 5495300300540669060 -g_device = torch.Generator(device="cuda") -# Uncomment the line below to use g_cpu if not using cuda -# g_device = torch.Generator() # noqa: E800 -# NOTE: remove below to stop repeatable runs +# Determine the device to use (CUDA if available, otherwise CPU) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# Create a generator for the specified device +g_device = torch.Generator(device=device) +# Set the seed for the generator g_device.manual_seed(random_seed) -print(f"\n\nWe are using seed: {random_seed}") +print(f"\n\nWe are using seed: {random_seed} on device: {device}") # ---------------------------------- # Loading torchvision MNIST datasets @@ -152,6 +154,11 @@ def FedAvg(models, previous_global_model=None, dp_params=None): # NOQA: N802 state_dict[key] = np.sum( np.array([state[key] for state in state_dicts], dtype=object), axis=0 ) / len(models) + # Convert numpy arrays to torch tensors + if isinstance(state_dict[key], np.ndarray): + if state_dict[key].dtype == np.object_: + state_dict[key] = state_dict[key].astype(np.float32) + state_dict[key] = torch.tensor(state_dict[key]) new_model.load_state_dict(state_dict) return new_model @@ -619,7 +626,7 @@ def end(self): # Setup participants # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU - aggregator = Aggregator(num_gpus=0.09) + aggregator = Aggregator(num_gpus=0.0) # Collaborator names collaborator_names = [ @@ -662,7 +669,7 @@ def callable_to_initialize_collaborator_private_attributes( private_attributes_callable=callable_to_initialize_collaborator_private_attributes, # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU num_cpus=0.0, - num_gpus=0.09, # Assuming GPU(s) is available in the machine + num_gpus=0.0, # Assuming GPU(s) is available in the machine index=idx, n_collaborators=len(collaborator_names), batch_size=batch_size_train, diff --git a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py index 1172f4107f..722009e5b3 100644 --- a/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py +++ b/openfl-tutorials/experimental/workflow/Global_DP/Workflow_Interface_Mnist_Implementation_2.py @@ -644,7 +644,7 @@ def end(self): # Initialize the aggregator with the callable # Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU - + def callable_to_initialize_aggregator_private_attributes(): # Load the configuration file to get dp_params with open(args.config_path, "rb") as _file: @@ -657,7 +657,7 @@ def callable_to_initialize_aggregator_private_attributes(): collaborator_names=collaborator_names, dp_params=dp_params )} - + aggregator = Aggregator( name="agg", private_attributes_callable= callable_to_initialize_aggregator_private_attributes, @@ -719,4 +719,4 @@ def callable_to_initialize_collaborator_private_attributes( clip_test=args.clip_test, ) flflow.runtime = local_runtime - flflow.run() \ No newline at end of file + flflow.run()