Skip to content

Commit

Permalink
remove calls for legacy optimizer
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Jun 7, 2024
1 parent bea8729 commit 05014d2
Showing 1 changed file with 7 additions and 28 deletions.
35 changes: 7 additions & 28 deletions openfl/federated/task/runner_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,7 @@ def _get_weights_names(obj, with_opt_vars):
The weight name list
"""
if with_opt_vars:
# When acquiring optimizer weights, check optimizer version.
# Current optimizer does not use 'weights' attributes
if 'legacy' in obj.__class__.__module__:
weight_names = [weight.name for weight in obj.weights]
else:
weight_names = [weight.name for weight in obj.variables]
weight_names = [weight.name for weight in obj.variables]

weight_names = [weight.name for weight in obj.weights]
return weight_names
Expand All @@ -287,14 +282,8 @@ def _get_weights_dict(obj, suffix='', with_opt_vars=False):

weights_dict = {}
if with_opt_vars:
# When acquiring optimizer weights, check optimizer version.
# Current optimizer does not use 'weights' or '.get_weights()' attributes
if 'legacy' in obj.__class__.__module__:
weight_names = [weight.name for weight in obj.weights]
weight_values = obj.get_weights()
else:
weight_names = [weight.name for weight in obj.variables]
weight_values = [weight.numpy() for weight in obj.variables]
weight_names = [weight.name for weight in obj.variables]
weight_values = [weight.numpy() for weight in obj.variables]
else:
weight_names = [weight.name for weight in obj.weights]
weight_values = obj.get_weights()
Expand All @@ -319,12 +308,7 @@ def _set_weights_dict(obj, weights_dict, with_opt_vars=False):
"""

if with_opt_vars:
# When acquiring optimizer weights, check optimizer version.
# Current optimizer does not use 'weights' attributes
if 'legacy' in obj.__class__.__module__:
weight_names = [weight.name for weight in obj.weights]
else:
weight_names = [weight.name for weight in obj.variables]
weight_names = [weight.name for weight in obj.variables]
else:
weight_names = [weight.name for weight in obj.weights]

Expand Down Expand Up @@ -383,15 +367,10 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars):
model_weights_dict = {
name: tensor_dict[name] for name in model_weight_names
}
if 'legacy' in self.model.optimizer.__class__.__module__:
opt_weight_names = [
weight.name for weight in self.model.optimizer.weights
]
else:
opt_weight_names = [
weight.name for weight in self.model.optimizer.variables
]

opt_weight_names = [
weight.name for weight in self.model.optimizer.variables
]
opt_weights_dict = {
name: tensor_dict[name] for name in opt_weight_names
}
Expand Down

0 comments on commit 05014d2

Please sign in to comment.