diff --git a/openfl/utilities/optimizers/keras/fedprox.py b/openfl/utilities/optimizers/keras/fedprox.py
index 64d467cdc8..3e83b1a707 100644
--- a/openfl/utilities/optimizers/keras/fedprox.py
+++ b/openfl/utilities/optimizers/keras/fedprox.py
@@ -41,7 +41,7 @@ def build(self, variables):
         """Initialize optimizer variables.
 
         Args:
-            var_list: list of model variables to build FedProx variables on.
+            variables (list): List of model variables to build FedProx variables on.
         """
         if self.built:
             return
@@ -55,15 +55,33 @@ def build(self, variables):
         )
 
     def update_step(self, gradient, variable, learning_rate):
-        """Update step given gradient and the associated model variable."""
+        """Update step given gradient and the associated model variable.
+        The variable is updated using the gradient and a proximal term scaled by mu.
+        The proximal term regularizes the update by penalizing the difference between
+        the current value of the variable and its initial value (vstar), which was stored during the build method.
+        Args:
+            gradient (tf.Tensor): The gradient tensor for the variable.
+            variable (tf.Variable): The model variable to be updated.
+            learning_rate (float): The learning rate for the update step.
+        """
         lr_t = tf.cast(learning_rate, variable.dtype)
         mu_t = tf.cast(self.mu, variable.dtype)
         gradient_t = tf.cast(gradient, variable.dtype)
+        # Get the corresponding vstar for the current variable
         vstar = self.vstars[self._get_variable_index(variable)]
+        # Update the variable using the gradient and the proximal term
         self.assign_sub(variable, lr_t * (gradient_t + mu_t * (variable - vstar)))
 
     def get_config(self):
+        """Return the config of the optimizer.
+        An optimizer config is a Python dictionary (serializable)
+        containing the configuration of an optimizer.
+        The same optimizer can be reinstantiated later
+        (without any saved state) from this configuration.
+        Returns:
+            dict: The optimizer configuration.
+        """
         config = super().get_config()
         config.update(
             {
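
For context beyond the diff: the line `self.assign_sub(variable, lr_t * (gradient_t + mu_t * (variable - vstar)))` implements the FedProx update `w <- w - lr * (grad + mu * (w - w_star))`. Below is a minimal standalone sketch of that single step in plain TensorFlow; the names `w`, `w_star`, and `grad` and the numeric values are illustrative only and not part of this patch.

```python
import tensorflow as tf

lr, mu = 0.01, 0.1                # learning rate and proximal coefficient
w = tf.Variable([1.0, 2.0])       # current model variable
w_star = tf.constant([0.5, 1.5])  # snapshot of w captured in build() (vstar)
grad = tf.constant([0.2, -0.4])   # gradient for this step

# FedProx step: w <- w - lr * (grad + mu * (w - w_star)).
# With mu = 0 this reduces to plain SGD; a larger mu pulls each update
# back toward the weights stored when build() ran.
w.assign_sub(lr * (grad + mu * (w - w_star)))

print(w.numpy())  # [0.9975, 2.0035]
```

In a Keras training setup this would amount to passing `mu` at construction, e.g. `model.compile(optimizer=FedProxOptimizer(learning_rate=0.01, mu=0.1), ...)`, assuming `FedProxOptimizer` is the class this module exports.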