Skip to content

Commit

Permalink
Fix trailing whitespace and end-of-file newlines
Browse files Browse the repository at this point in the history
Signed-off-by: Agniva Chowdhury <[email protected]>
  • Loading branch information
agnivac123 committed Dec 23, 2024
1 parent b8cf66e commit 3ecc561
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@

# Fixing the seed for result repeatation: remove below to stop repeatable runs
# ----------------------------------
# Fixing the seed for result reproducibility
random_seed = 5495300300540669060
g_device = torch.Generator(device="cuda")
# Uncomment the line below to use g_cpu if not using cuda
# g_device = torch.Generator() # noqa: E800
# NOTE: remove below to stop repeatable runs
# Determine the device to use (CUDA if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create a generator for the specified device
g_device = torch.Generator(device=device)
# Set the seed for the generator
g_device.manual_seed(random_seed)
print(f"\n\nWe are using seed: {random_seed}")
print(f"\n\nWe are using seed: {random_seed} on device: {device}")
# ----------------------------------

# Loading torchvision MNIST datasets
Expand Down Expand Up @@ -152,6 +154,11 @@ def FedAvg(models, previous_global_model=None, dp_params=None): # NOQA: N802
state_dict[key] = np.sum(
np.array([state[key] for state in state_dicts], dtype=object), axis=0
) / len(models)
# Convert numpy arrays to torch tensors
if isinstance(state_dict[key], np.ndarray):
if state_dict[key].dtype == np.object_:
state_dict[key] = state_dict[key].astype(np.float32)
state_dict[key] = torch.tensor(state_dict[key])
new_model.load_state_dict(state_dict)
return new_model

Expand Down Expand Up @@ -619,7 +626,7 @@ def end(self):

# Setup participants
# Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU
aggregator = Aggregator(num_gpus=0.09)
aggregator = Aggregator(num_gpus=0.0)

# Collaborator names
collaborator_names = [
Expand Down Expand Up @@ -662,7 +669,7 @@ def callable_to_initialize_collaborator_private_attributes(
private_attributes_callable=callable_to_initialize_collaborator_private_attributes,
# Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU
num_cpus=0.0,
num_gpus=0.09, # Assuming GPU(s) is available in the machine
num_gpus=0.0, # Assuming GPU(s) is available in the machine
index=idx,
n_collaborators=len(collaborator_names),
batch_size=batch_size_train,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def end(self):

# Initialize the aggregator with the callable
# Set `num_gpus=0.09` to `num_gpus=0.0` in order to run this tutorial on CPU

def callable_to_initialize_aggregator_private_attributes():
# Load the configuration file to get dp_params
with open(args.config_path, "rb") as _file:
Expand All @@ -657,7 +657,7 @@ def callable_to_initialize_aggregator_private_attributes():
collaborator_names=collaborator_names,
dp_params=dp_params
)}

aggregator = Aggregator(
name="agg",
private_attributes_callable= callable_to_initialize_aggregator_private_attributes,
Expand Down Expand Up @@ -719,4 +719,4 @@ def callable_to_initialize_collaborator_private_attributes(
clip_test=args.clip_test,
)
flflow.runtime = local_runtime
flflow.run()
flflow.run()

0 comments on commit 3ecc561

Please sign in to comment.