Skip to content

Commit

Permalink
fix(configs): loss state saving and loading
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 10, 2024
1 parent 725fcac commit cdbf88f
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion neuralnetlib/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
def __str__(self):
    """Return a readable summary of the loss, including its smoothing factor."""
    return "CrossEntropyWithLabelSmoothing(label_smoothing={})".format(self.label_smoothing)

def get_config(self) -> dict:
    """Serialize this loss to a plain dict so it can be saved and reloaded."""
    config = {"name": self.__class__.__name__}
    config["label_smoothing"] = self.label_smoothing
    return config


class Wasserstein(LossFunction):
def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
Expand Down Expand Up @@ -355,6 +361,15 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:

return grad

def get_config(self) -> dict:
    """Serialize the focal-style loss hyperparameters to a plain dict."""
    # type(self).__name__ is equivalent to self.__class__.__name__ here.
    return dict(
        name=type(self).__name__,
        gamma=self.gamma,
        alpha=self.alpha,
        scale=self.scale,
    )


class MultiLabelBCELoss(LossFunction):
def __init__(self, pos_weight: float = 1.0):
super().__init__()
Expand All @@ -379,6 +394,12 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
grad = np.clip(grad, -10, 10)

return grad / y_true.size

def get_config(self) -> dict:
    """Return the state needed to reconstruct this loss after loading."""
    cls_name = type(self).__name__
    return {"name": cls_name, "pos_weight": self.pos_weight}


class AsymmetricLoss(LossFunction):
Expand Down Expand Up @@ -426,4 +447,12 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
d_neg_focusing * d_neg_bce
)

return gradient / y_true.shape[0]
return gradient / y_true.shape[0]

def get_config(self) -> dict:
    """Serialize the asymmetric-loss hyperparameters for checkpointing."""
    hyperparams = ("gamma_pos", "gamma_neg", "clip")
    config = {"name": type(self).__name__}
    config.update((key, getattr(self, key)) for key in hyperparams)
    return config

0 comments on commit cdbf88f

Please sign in to comment.