fix attention after the error introduced in the latest PR
ASEM000 committed Apr 10, 2024
1 parent a129c7b commit d07404b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions serket/_src/nn/attention.py
```diff
@@ -78,9 +78,9 @@ def dot_product_attention(
     - https://keras.io/api/layers/attention_layers/multi_head_attention/
     - https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html
     """
-    *_, num_heads, k_depth = k_heads.shape
+    *_other, _length, _num_heads, k_depth = k_heads.shape
     logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads)
-    logits /= jnp.sqrt(k_depth // num_heads)
+    logits /= jnp.sqrt(k_depth)
     min_num = jnp.finfo(logits.dtype).min
     logits = logits if mask is None else jnp.where(mask, logits, min_num)
     weight = jax.nn.softmax(logits)
```
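The fix hinges on the shape convention: in the patched unpacking, `k_depth` is already the per-head depth of `k_heads` (shaped `(..., length, num_heads, head_depth)`), so the logits should be scaled by `sqrt(k_depth)` directly; the old `sqrt(k_depth // num_heads)` divided the per-head depth a second time. A minimal standalone sketch of the corrected scaling, with `scaled_attention_weights` as a hypothetical name for illustration (not the serket API):

```python
import jax
import jax.numpy as jnp

def scaled_attention_weights(q_heads, k_heads, mask=None):
    # q_heads: (..., q_length, num_heads, head_depth)
    # k_heads: (..., k_length, num_heads, head_depth)
    # k_depth here is the per-head depth, so scale by sqrt(k_depth) directly.
    *_other, _length, _num_heads, k_depth = k_heads.shape
    # Contract over the depth axis: result is (..., num_heads, q_length, k_length).
    logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads)
    logits /= jnp.sqrt(k_depth)
    # Masked positions get the dtype's most negative value so softmax zeros them.
    min_num = jnp.finfo(logits.dtype).min
    logits = logits if mask is None else jnp.where(mask, logits, min_num)
    return jax.nn.softmax(logits)  # normalizes over the key axis (last axis)

q = jnp.ones((4, 2, 8))  # (q_length=4, num_heads=2, head_depth=8)
k = jnp.ones((5, 2, 8))  # (k_length=5, num_heads=2, head_depth=8)
w = scaled_attention_weights(q, k)
print(w.shape)  # (2, 4, 5): (num_heads, q_length, k_length)
```

Each row of the weight tensor sums to 1 over the key axis, and with uniform inputs every key receives equal weight, which makes the scaling bug invisible in outputs but not in gradients or temperature: the old code sharpened the softmax by a factor of `sqrt(num_heads)`.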
