fix(configs): optimizers states saving and loading
marcpinet committed Dec 10, 2024
1 parent cdbf88f commit 1fcafb5
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions neuralnetlib/optimizers.py
@@ -81,15 +81,24 @@ def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray
         bias += self.velocity_b
 
     def get_config(self) -> dict:
-        return {"name": self.__class__.__name__, "learning_rate": self.learning_rate, "momentum": self.momentum,
-                "velocity": self.velocity}
-
-    def __str__(self):
-        return f"{self.__class__.__name__}(learning_rate={self.learning_rate}, momentum={self.momentum})"
+        return {
+            "name": self.__class__.__name__,
+            "learning_rate": self.learning_rate,
+            "momentum": self.momentum,
+            "velocity_w": dict_with_ndarray_to_dict_with_list(self.velocity_w) if hasattr(self, 'velocity_w') else None,
+            "velocity_b": dict_with_ndarray_to_dict_with_list(self.velocity_b) if hasattr(self, 'velocity_b') else None
+        }
 
     @staticmethod
     def from_config(config: dict):
-        return Momentum(config['learning_rate'], config['momentum'])
+        optimizer = Momentum(config['learning_rate'], config['momentum'])
+        if config.get('velocity_w'):
+            optimizer.velocity_w = dict_with_list_to_dict_with_ndarray(config['velocity_w'])
+            optimizer.velocity_b = dict_with_list_to_dict_with_ndarray(config['velocity_b'])
+        return optimizer
+
+    def __str__(self):
+        return f"{self.__class__.__name__}(learning_rate={self.learning_rate}, momentum={self.momentum})"
 
 
 class RMSprop(Optimizer):
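
Note: the hunk above serializes state through dict_with_ndarray_to_dict_with_list and restores it through dict_with_list_to_dict_with_ndarray, neither of which appears in this diff. A minimal sketch of what such helpers could look like, assuming the state dicts map layer indices to numpy arrays (the library's actual implementation may differ):

import numpy as np


def dict_with_ndarray_to_dict_with_list(d: dict) -> dict:
    # Hypothetical sketch: turn each ndarray value into a plain list so the
    # optimizer config stays JSON-serializable.
    return {key: value.tolist() for key, value in d.items()}


def dict_with_list_to_dict_with_ndarray(d: dict) -> dict:
    # Hypothetical sketch: rebuild ndarray values from the saved lists.
    return {key: np.array(value) for key, value in d.items()}
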
@@ -117,12 +126,22 @@ def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray
                     (np.sqrt(self.sq_grads_b) + self.epsilon)
 
     def get_config(self) -> dict:
-        return {"name": self.__class__.__name__, "learning_rate": self.learning_rate, "rho": self.rho,
-                "epsilon": self.epsilon, "sq_grads": self.sq_grads}
+        return {
+            "name": self.__class__.__name__,
+            "learning_rate": self.learning_rate,
+            "rho": self.rho,
+            "epsilon": self.epsilon,
+            "sq_grads_w": dict_with_ndarray_to_dict_with_list(self.sq_grads_w) if hasattr(self, 'sq_grads_w') else None,
+            "sq_grads_b": dict_with_ndarray_to_dict_with_list(self.sq_grads_b) if hasattr(self, 'sq_grads_b') else None
+        }
 
     @staticmethod
     def from_config(config: dict):
-        return RMSprop(config['learning_rate'], config['rho'], config['epsilon'])
+        optimizer = RMSprop(config['learning_rate'], config['rho'], config['epsilon'])
+        if config.get('sq_grads_w'):
+            optimizer.sq_grads_w = dict_with_list_to_dict_with_ndarray(config['sq_grads_w'])
+            optimizer.sq_grads_b = dict_with_list_to_dict_with_ndarray(config['sq_grads_b'])
+        return optimizer
 
     def __str__(self):
         return f"{self.__class__.__name__}(learning_rate={self.learning_rate}, rho={self.rho}, epsilon={self.epsilon})"
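
With the accumulated statistics included in the config, saving and reloading an optimizer no longer resets its state. A usage sketch (hyperparameter values are illustrative; it assumes RMSprop is imported from neuralnetlib.optimizers and that some training has already populated sq_grads_w / sq_grads_b):

from neuralnetlib.optimizers import RMSprop

optimizer = RMSprop(0.001, 0.9, 1e-8)  # illustrative learning_rate, rho, epsilon
# ... training updates would populate optimizer.sq_grads_w / optimizer.sq_grads_b ...

config = optimizer.get_config()          # ndarray state stored as plain lists
restored = RMSprop.from_config(config)   # lists converted back to ndarrays
# restored resumes with the same accumulated squared gradients instead of
# starting them from scratch, which is what this commit fixes.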
