diff --git a/fanan/modules/attentions/__init__.py b/fanan/modules/attentions/__init__.py
deleted file mode 100644
index 4838032..0000000
--- a/fanan/modules/attentions/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from typing import Any, Dict
-
-_ATTENTIONS: Dict[str, Any] = {}
-
-
-def register_attention_fn(fn):
-    _ATTENTIONS[fn.__name__.lower()] = fn
-    return fn
-
-
-from fanan.modules.attentions.self_attention import * # noqa: E402, F403
-
-
-def get_attention_fn(name: str):
-    assert name in _ATTENTIONS, f"attention fn {name=} is not supported. Available attentions: {_ATTENTIONS.keys()}"
-    return _ATTENTIONS[name.lower()]
diff --git a/fanan/modules/attentions/self_attention.py b/fanan/modules/attentions/self_attention.py
deleted file mode 100644
index 270fa2c..0000000
--- a/fanan/modules/attentions/self_attention.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import math
-
-import jax
-import jax.numpy as jnp
-from beartype import beartype as typechecker
-from jaxtyping import Array, Float, jaxtyped
-
-from fanan.modules.attentions import register_attention_fn
-
-
-@register_attention_fn
-@jaxtyped(typechecker=typechecker)
-def self_attention(
-    query: Float[Array, "batch_size sequence_length n_heads head_dim"],
-    value: Float[Array, "batch_size sequence_length n_heads head_dim"],
-    key: Float[Array, "batch_size sequence_length n_heads head_dim"],
-    mask: jax.Array = None,
-) -> Float[Array, "batch_size sequence_length n_heads head_dim"]:
-    """Self attention mechanism."""
-    kv_heads = key.shape[-2]
-    q_heads, head_dim = query.shape[-2], query.shape[-1]
-
-    if q_heads != kv_heads:
-        assert q_heads > kv_heads
-        tile_factor = q_heads // kv_heads
-        key = jnp.repeat(key, tile_factor, axis=-2)
-        value = jnp.repeat(value, tile_factor, axis=-2)
-
-    scale = float(1 / math.sqrt(head_dim))
-
-    attention_logits = jnp.einsum("bthd,bThd->bhtT", query, key)
-    attention_logits = (attention_logits * scale).astype(query.dtype)
-
-    attention_weights = jax.nn.softmax(attention_logits, axis=-1)
-    attention_weights = attention_weights.astype(value.dtype)
-
-    attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, value)
-
-    return attention_vec
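For context, a minimal sketch of how the removed registry and attention function were consumed before this deletion. The shapes and the PRNG setup here are illustrative assumptions, not taken from the repository; only `get_attention_fn`, `register_attention_fn`, and `self_attention` come from the deleted files above (note the deleted signature orders `value` before `key`, so keyword arguments are used).

```python
import jax

# Hypothetical shapes for illustration: batch=2, seq_len=16, 8 heads, head_dim=64.
# The deleted __init__.py registered functions under their lowercased __name__,
# so the self-attention kernel was looked up as "self_attention".
from fanan.modules.attentions import get_attention_fn  # module removed by this diff

rng = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(rng, 3)
query = jax.random.normal(q_key, (2, 16, 8, 64))
key = jax.random.normal(k_key, (2, 16, 8, 64))
value = jax.random.normal(v_key, (2, 16, 8, 64))

attn_fn = get_attention_fn("self_attention")
out = attn_fn(query=query, value=value, key=key)  # -> shape (2, 16, 8, 64)
```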