Commit

cleanup
albertz committed Feb 7, 2025
1 parent d05f862 commit bd43b8e
Showing 1 changed file with 3 additions and 38 deletions.
@@ -186,20 +186,6 @@ def __init__(
                 auxiliary=True,
                 non_critical_for_restore=True,
             )
-        self.prior_running_mean_momentum = config.typed_value("prior_running_mean_momentum", None)
-        self.prior_running_mean_per_layer = config.bool("prior_running_mean_per_layer", False)
-        self.prior_running_mean = None  # in std prob, if set
-        if self.prior_running_mean_momentum is not None:
-            self.prior_running_mean = rf.Parameter(
-                [self.wb_target_dim], auxiliary=True, initial=1.0 / self.wb_target_dim.dimension
-            )
-            if self.prior_running_mean_per_layer:
-                for i in enc_aux_logits:
-                    setattr(
-                        self,
-                        f"prior_running_mean_{i}",
-                        rf.Parameter([self.wb_target_dim], auxiliary=True, initial=1.0 / self.wb_target_dim.dimension),
-                    )

         if target_dim.vocab and not wb_target_dim.vocab:
             from returnn.datasets.util.vocabulary import Vocabulary
@@ -354,11 +340,9 @@ def aux_logits_from_collected_outputs(self, aux_layer: int, collected_outputs: D
         aux_logits = linear(collected_outputs[str(aux_layer - 1)])
         return aux_logits

-    def log_probs_wb_from_logits(self, logits: Tensor, *, aux_layer: Optional[int] = None) -> Tensor:
+    def log_probs_wb_from_logits(self, logits: Tensor) -> Tensor:
         """
         :param logits: incl blank
-        :param aux_layer: whether the logits come from some intermediate aux layer.
-            That might influence the prior.
         :return: log probs with blank from logits (wb_target_dim)
             If out_blank_separated, we use a separate sigmoid for the blank.
             Also, potentially adds label smoothing on the gradients.
@@ -376,7 +360,7 @@ def log_probs_wb_from_logits(self, logits: Tensor, *, aux_layer: Optional[int] =
                 logits, axis=self.wb_target_dim, out_dims=[self.target_dim, dummy_blank_feat_dim]
             )
             log_probs_wo_blank = rf.log_softmax(logits_wo_blank, axis=self.target_dim)
-            log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank, aux_layer=aux_layer)
+            log_probs_wo_blank = self._maybe_apply_on_log_probs(log_probs_wo_blank)
             if self.blank_logit_shift:
                 logits_blank += self.blank_logit_shift
             log_probs_blank = rf.log_sigmoid(logits_blank)
@@ -388,22 +372,7 @@ def log_probs_wb_from_logits(self, logits: Tensor, *, aux_layer: Optional[int] =
             )
             log_probs.feature_dim = self.wb_target_dim

-        prior_running_mean = None
-        if self.prior_running_mean_momentum is not None:
-            prior_running_mean = self.prior_running_mean
-            if self.prior_running_mean_per_layer and aux_layer is not None:
-                prior_running_mean = getattr(self, f"prior_running_mean_{aux_layer}")
-
-            def _update_running_stats():
-                batch_prior = rf.reduce_mean(
-                    rf.exp(log_probs), axis=[d for d in log_probs.dims if d != self.wb_target_dim]
-                )
-                assert batch_prior.dims == (self.wb_target_dim,)
-                prior_running_mean.assign_add(self.prior_running_mean_momentum * (batch_prior - prior_running_mean))
-
-            rf.cond(rf.get_run_ctx().train_flag, _update_running_stats, lambda: None)
-
-        log_probs = self._maybe_apply_on_log_probs(log_probs, aux_layer=aux_layer)
+        log_probs = self._maybe_apply_on_log_probs(log_probs)
         if self.ctc_am_scale == 1 and self.ctc_prior_scale == 0:  # fast path
             return log_probs
         log_probs_am = log_probs
@@ -443,10 +412,6 @@ def _update_running_stats():
         elif self.ctc_prior_type == "static":
             log_prob_prior = self.static_prior
             assert log_prob_prior.dims == (self.wb_target_dim,)
-        elif self.ctc_prior_type == "running_mean":
-            assert prior_running_mean is not None
-            log_prob_prior = rf.safe_log(prior_running_mean)
-            assert log_prob_prior.dims == (self.wb_target_dim,)
         else:
             raise ValueError(f"invalid ctc_prior_type {self.ctc_prior_type!r}")
         log_probs -= log_prob_prior * self.ctc_prior_scale
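For context, the removed "running_mean" prior kept an exponential moving average of the batch-mean label posterior (incl. blank), updated only under the training flag, and with prior_running_mean_per_layer a separate running mean per aux layer; when ctc_prior_type == "running_mean", its scaled log was subtracted from the CTC log probs. A minimal standalone sketch of that idea in plain NumPy follows (illustrative names and shapes, not the RETURNN rf API used in the diff):

import numpy as np


def update_running_prior(running_prior: np.ndarray, log_probs: np.ndarray, momentum: float) -> np.ndarray:
    """
    :param running_prior: [num_labels], running mean in (non-log) prob space
    :param log_probs: [..., num_labels], per-frame log probs of the current batch
    :param momentum: e.g. 0.001, fraction of the batch prior mixed in per step
    :return: updated running prior, [num_labels]
    """
    probs = np.exp(log_probs)
    # Batch prior: mean over every axis except the label axis.
    batch_prior = probs.reshape(-1, probs.shape[-1]).mean(axis=0)
    return running_prior + momentum * (batch_prior - running_prior)


def apply_prior(log_probs: np.ndarray, running_prior: np.ndarray, prior_scale: float) -> np.ndarray:
    """Subtract the scaled log prior from the log probs (cf. ctc_prior_scale)."""
    return log_probs - prior_scale * np.log(np.maximum(running_prior, 1e-20))


# Usage sketch: start from a uniform prior (1 / num_labels), update once per training step.
num_labels = 5
prior = np.full([num_labels], 1.0 / num_labels)
log_probs = np.log(np.random.dirichlet(np.ones(num_labels), size=(2, 7)))  # [batch, time, labels]
prior = update_running_prior(prior, log_probs, momentum=0.001)
log_probs_with_prior = apply_prior(log_probs, prior, prior_scale=0.3)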
