fix: MultiHeadDotProductAttention and optax ctc_loss changes
init-22 committed Dec 22, 2024
1 parent d21d820 commit 785d82b
Showing 3 changed files with 11 additions and 10 deletions.
@@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module):
   def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     if not self.use_post_layer_norm:
       y = nn.LayerNorm(name='LayerNorm_0')(x)
-      y = nn.SelfAttention(
+      y = nn.MultiHeadDotProductAttention(
           num_heads=self.num_heads,
           kernel_init=nn.initializers.xavier_uniform(),
           deterministic=train,
@@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
       x = x + y
     else:
       y = x
-      y = nn.SelfAttention(
+      y = nn.MultiHeadDotProductAttention(
           num_heads=self.num_heads,
           kernel_init=nn.initializers.xavier_uniform(),
           deterministic=train,
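Note: this change appears to track the deprecation of nn.SelfAttention in recent Flax releases, where nn.MultiHeadDotProductAttention called on a single input reuses that input for queries, keys, and values. A minimal sketch of the assumed drop-in usage (module setup and shapes are illustrative, not from this repository):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Self-attention via MultiHeadDotProductAttention: one input serves as q, k, and v.
attn = nn.MultiHeadDotProductAttention(
    num_heads=4,
    kernel_init=nn.initializers.xavier_uniform(),
    deterministic=True)  # disable attention dropout, as in eval mode

x = jnp.ones((2, 16, 32))                     # (batch, sequence, features)
params = attn.init(jax.random.PRNGKey(0), x)  # single input => self-attention
y = attn.apply(params, x)                     # output has the same shape as x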
@@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train):
         mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)

     inputs = LayerNorm(dim=config.encoder_dim)(inputs)
-
     attention_fn = functools.partial(
         dot_product_attention, temperature=config.attention_temperature)
-    result = nn.SelfAttention(
+    result = nn.MultiHeadDotProductAttention(
         num_heads=config.num_attention_heads,
         qkv_features=config.encoder_dim,
         decode=False,
@@ -410,7 +409,8 @@
         broadcast_dropout=False,
         attention_fn=attention_fn,
         dropout_rate=config.attention_dropout_rate,
-        deterministic=not train)(inputs, attention_mask)
+        deterministic=not train)(
+            inputs_q=inputs, mask=attention_mask)

     if config.attention_residual_dropout_rate is None:
       attention_residual_dropout_rate = 0.1
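Note: under the newer Flax calling convention assumed here, the attention module is invoked with keyword arguments; passing only inputs_q together with a mask built by nn.make_attention_mask gives masked self-attention over the padded batch. A hedged sketch of that pattern (shapes and values are illustrative, not from this repository):

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(
    num_heads=8, qkv_features=64, deterministic=True)

x = jnp.ones((2, 20, 64))      # (batch, time, features)
paddings = jnp.zeros((2, 20))  # 0.0 = valid frame, 1.0 = padded frame
mask = nn.make_attention_mask(
    paddings == 0, paddings == 0, dtype=jnp.float32)

params = attn.init(jax.random.PRNGKey(0), inputs_q=x, mask=mask)
out = attn.apply(params, inputs_q=x, mask=mask)  # masked self-attention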
@@ -227,11 +227,12 @@ def ctc_loss(self,
                labels: spec.Tensor,
                label_paddings: spec.Tensor,
                blank_id: int = 0) -> spec.Tensor:
-    return optax.ctc_loss(logits,
-                          logit_paddings,
-                          labels,
-                          label_paddings,
-                          blank_id)
+    return optax.ctc_loss(
+        logits=logits,
+        logit_paddings=logit_paddings,
+        labels=labels,
+        label_paddings=label_paddings,
+        blank_id=blank_id)

   # Adapted from lingvo's greedy decoding logic here:
   # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.
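Note: the keyword-argument form matches the optax.ctc_loss signature (logits, logit_paddings, labels, label_paddings, blank_id) and returns a per-example loss. A small sketch with illustrative shapes (not taken from this repository):

import jax.numpy as jnp
import optax

batch, time, vocab, label_len = 2, 50, 29, 10
logits = jnp.zeros((batch, time, vocab))           # (B, T, K) unnormalized scores
logit_paddings = jnp.zeros((batch, time))          # 1.0 marks padded frames
labels = jnp.ones((batch, label_len), jnp.int32)   # target ids; blank_id is reserved
label_paddings = jnp.zeros((batch, label_len))     # 1.0 marks padded label positions

per_example_loss = optax.ctc_loss(
    logits=logits,
    logit_paddings=logit_paddings,
    labels=labels,
    label_paddings=label_paddings,
    blank_id=0)                                     # shape (B,)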
