revert one loss-function rename
This reverts the rename to avoid breaking compatibility with model checkpoint files created before the previous commit, such as the pre-trained model from the CoverHunter authors. A short illustrative sketch follows the change summary below.
alanngnet committed Nov 1, 2024
1 parent 4f87bad commit d74466f
Showing 1 changed file with 6 additions and 3 deletions.
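
For context, here is a minimal sketch (not code from this repository) of why the earlier attribute rename breaks older checkpoints: PyTorch derives state_dict keys from submodule attribute names, so a checkpoint saved with a "_ce_layer.weight" key cannot be loaded strictly into a model whose layer is named "_foc_layer". The Old/Renamed classes and layer sizes below are illustrative only.

import torch

class Old(torch.nn.Module):
    # mirrors the legacy checkpoint layout: the linear layer is stored as "_ce_layer.weight"
    def __init__(self):
        super().__init__()
        self._ce_layer = torch.nn.Linear(4, 2, bias=False)

class Renamed(torch.nn.Module):
    # what the previous commit's rename would have produced
    def __init__(self):
        super().__init__()
        self._foc_layer = torch.nn.Linear(4, 2, bias=False)

ckpt = Old().state_dict()           # contains the key "_ce_layer.weight"
model = Renamed()
try:
    model.load_state_dict(ckpt)     # strict=True by default
except RuntimeError as err:
    print(err)                      # reports missing "_foc_layer.weight" and unexpected "_ce_layer.weight"
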
9 changes: 6 additions & 3 deletions src/model.py
@@ -224,7 +224,10 @@ def __init__(self, hp: Dict) -> None:
         self._pool_layer = AttentiveStatisticsPooling(
             hp["embed_dim"], output_channels=hp["embed_dim"],
         )
-        self._foc_layer = torch.nn.Linear(
+        # _ce_layer should be _foc_layer but retaining historical name to avoid
+        # breaking compatibility with legacy checkpoint files such as pre-trained models
+        # see CoverHunter paper about their switch from cross-entropy to focal loss
+        self._ce_layer = torch.nn.Linear(
             hp["embed_dim"], hp["foc"]["output_dims"], bias=False,
         )

@@ -324,7 +327,7 @@ def compute_loss(

         # FocalLoss
         f_i = self._bottleneck(f_t)
-        foc_pred = self._foc_layer(f_i)
+        foc_pred = self._ce_layer(f_i)
         foc_loss = self._foc_loss(foc_pred, label)
         loss = foc_loss * self._hp["foc"]["weight"]
         loss_dict = {"foc_loss": foc_loss}
@@ -346,7 +349,7 @@ def compute_loss(
     def inference(self, feat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         with torch.no_grad():
             embed = self.forward(feat)
-            embed_foc = self._foc_layer(embed)
+            embed_foc = self._ce_layer(embed)
             return embed, embed_foc

     # @torch.jit.export
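
For comparison, a hypothetical alternative to the comment in the first hunk (not what this commit does, and not code from the repo) would be to keep the clearer _foc_layer name in the code and remap legacy keys when a checkpoint is loaded. The helper name load_legacy_checkpoint is invented for illustration, and it assumes the file stores the state_dict directly.

import torch

def load_legacy_checkpoint(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    # illustrative only: rename legacy "_ce_layer.*" keys to "_foc_layer.*" before loading
    state = torch.load(ckpt_path, map_location="cpu")  # assumes the file holds a plain state_dict
    remapped = {k.replace("_ce_layer.", "_foc_layer."): v for k, v in state.items()}
    model.load_state_dict(remapped)
    return model

Retaining the historical name, as this commit does, avoids that extra translation step and keeps new checkpoints interchangeable with the authors' pre-trained model.
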
