Skip to content

Commit

Permalink
add better documentation to the fedprox optimizer
Browse files Browse the repository at this point in the history
Signed-off-by: Buchnik, Yehonatan <[email protected]>
  • Loading branch information
yontyon committed Dec 12, 2024
1 parent 580d129 commit c215e2f
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions openfl/utilities/optimizers/keras/fedprox.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def build(self, variables):
"""Initialize optimizer variables.
Args:
var_list: list of model variables to build FedProx variables on.
variables (list): List of model variables to build FedProx variables on.
"""
if self.built:
return
Expand All @@ -55,15 +55,33 @@ def build(self, variables):
)

def update_step(self, gradient, variable, learning_rate):
    """Apply one FedProx update to a single model variable.

    The variable is moved against its gradient, with an extra proximal
    term ``mu * (variable - vstar)`` that pulls it back toward the
    anchor value ``vstar`` captured for this variable in ``build``.
    This regularizes local updates by penalizing drift from the
    initial (global) weights.

    Args:
        gradient (tf.Tensor): Gradient for `variable`.
        variable (tf.Variable): Model variable to update in place.
        learning_rate (float): Step size for this update.
    """
    # Cast all scalars/tensors to the variable's dtype before mixing them.
    grad = tf.cast(gradient, variable.dtype)
    mu = tf.cast(self.mu, variable.dtype)
    lr = tf.cast(learning_rate, variable.dtype)

    # Anchor weights recorded for this variable when the optimizer was built.
    anchor = self.vstars[self._get_variable_index(variable)]

    # variable <- variable - lr * (grad + mu * (variable - anchor))
    self.assign_sub(variable, lr * (grad + mu * (variable - anchor)))

def get_config(self):
"""Return the config of the optimizer.
An optimizer config is a Python dictionary (serializable)
containing the configuration of an optimizer.
The same optimizer can be reinstantiated later
(without any saved state) from this configuration.
Returns:
dict: The optimizer configuration.
"""
config = super().get_config()
config.update(
{
Expand Down

0 comments on commit c215e2f

Please sign in to comment.