Skip to content

Commit

Permalink
removed duplicate code
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Dec 2, 2024
1 parent 5e7eac6 commit 2f94d92
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions openfl/federated/task/runner_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,16 +302,13 @@ def _get_weights_dict(obj, suffix=""):
weights_dict (dict): The weight dictionary.
"""
weights_dict = {}
weight_names = KerasTaskRunner._get_weights_names(obj)
if isinstance(obj, ke.optimizers.Optimizer):
weight_names = [weight.name for weight in obj.variables]
weights_dict = {
weight_names[i] + suffix: weight.numpy()
for i, weight in enumerate(copy.deepcopy(obj.variables))
}
else:
weight_names = [
layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights
]
weight_name_index = 0
for layer in obj.layers:
if weight_name_index < len(weight_names) and len(layer.get_weights()) > 0:
Expand All @@ -329,14 +326,8 @@ def _set_weights_dict(obj, weights_dict):
the weights.
weights_dict (dict): The weight dictionary.
"""
if isinstance(obj, ke.optimizers.Optimizer):
weight_names = [weight.name for weight in obj.variables]
weight_values = [weights_dict[name] for name in weight_names]
else:
weight_names = [
layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights
]
weight_values = [weights_dict[name] for name in weight_names]
weight_names = KerasTaskRunner._get_weights_names(obj)
weight_values = [weights_dict[name] for name in weight_names]
obj.set_weights(weight_values)

def get_tensor_dict(self, with_opt_vars, suffix=""):
Expand Down

0 comments on commit 2f94d92

Please sign in to comment.