
Commit

tests
AmrMKayid committed May 14, 2024
1 parent 2d65221 commit d2b29ca
Showing 6 changed files with 136 additions and 1 deletion.
28 changes: 28 additions & 0 deletions .github/tests.yaml
@@ -0,0 +1,28 @@
name: Unittests

on: [push, pull_request]

jobs:
  test:
    name: Run Unittests
    runs-on: ubuntu-latest

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: "3.x"

      - name: Install Poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
          # Persist Poetry's install location so later steps can find the binary.
          echo "$HOME/.local/bin" >> "$GITHUB_PATH"

      - name: Install Dependencies
        run: poetry install

      - name: Run Unittests
        run: poetry run python -m unittest discover -s tests -v
16 changes: 16 additions & 0 deletions fanan/modules/attentions/__init__.py
@@ -0,0 +1,16 @@
from typing import Any, Dict

_ATTENTIONS: Dict[str, Any] = {}


def register_attention_fn(fn):
    _ATTENTIONS[fn.__name__.lower()] = fn
    return fn


# Imported after the decorator is defined so that attention implementations register themselves on package import.
from fanan.modules.attentions.self_attention import *  # noqa: E402, F403


def get_attention_fn(name: str):
    name = name.lower()
    assert name in _ATTENTIONS, f"attention fn {name=} is not supported. Available attentions: {list(_ATTENTIONS)}"
    return _ATTENTIONS[name]
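
A minimal usage sketch of this registry (illustrative only; dummy_attention is a hypothetical name, not part of the commit): implementations register themselves via the decorator at import time and are looked up by their lowercased function name.

# Hypothetical registration and lookup, assuming the package above is importable.
from fanan.modules.attentions import get_attention_fn, register_attention_fn


@register_attention_fn
def dummy_attention(query, value, key, mask=None):
    # Identity "attention", used only to demonstrate registration.
    return query


assert get_attention_fn("DUMMY_ATTENTION") is dummy_attention  # lookup is case-insensitive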
39 changes: 39 additions & 0 deletions fanan/modules/attentions/self_attention.py
@@ -0,0 +1,39 @@
import math

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped

from fanan.modules.attentions import register_attention_fn


@register_attention_fn
@jaxtyped(typechecker=typechecker)
def self_attention(
    query: Float[Array, "batch_size sequence_length n_heads head_dim"],
    value: Float[Array, "batch_size sequence_length kv_heads head_dim"],
    key: Float[Array, "batch_size sequence_length kv_heads head_dim"],
    mask: jax.Array | None = None,
) -> Float[Array, "batch_size sequence_length n_heads head_dim"]:
    """Scaled dot-product self-attention with optional grouped-query key/value heads."""
    kv_heads = key.shape[-2]
    q_heads, head_dim = query.shape[-2], query.shape[-1]

    if q_heads != kv_heads:
        # Grouped-query attention: repeat the key/value heads to match the query heads.
        assert q_heads > kv_heads and q_heads % kv_heads == 0
        tile_factor = q_heads // kv_heads
        key = jnp.repeat(key, tile_factor, axis=-2)
        value = jnp.repeat(value, tile_factor, axis=-2)

    scale = float(1 / math.sqrt(head_dim))

    attention_logits = jnp.einsum("bthd,bThd->bhtT", query, key)
    attention_logits = (attention_logits * scale).astype(query.dtype)

    if mask is not None:
        # Mask out disallowed positions before the softmax (True = keep).
        attention_logits = jnp.where(mask, attention_logits, jnp.finfo(attention_logits.dtype).min)

    attention_weights = jax.nn.softmax(attention_logits, axis=-1)
    attention_weights = attention_weights.astype(value.dtype)

    attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, value)

    return attention_vec
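
A rough call sketch for the grouped-query branch above (shapes are assumptions, not part of the commit): keys and values carry fewer heads than the queries and are repeated internally to match.

# Illustrative grouped-query call: 8 query heads, 2 key/value heads (assumed shapes).
import jax
import jax.numpy as jnp

from fanan.modules.attentions.self_attention import self_attention

rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (2, 16, 8, 64))  # (batch, seq, q_heads, head_dim)
k = jax.random.normal(rng_k, (2, 16, 2, 64))  # fewer kv heads, repeated 4x internally
v = jax.random.normal(rng_v, (2, 16, 2, 64))

out = self_attention(query=q, value=v, key=k)  # note the (query, value, key) order
print(out.shape)  # (2, 16, 8, 64)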
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ jmp = "0.0.4"
jaxtyping = "0.2.28"
pre-commit = "3.7.0"
ipdb = "0.13.13"
beartype = "0.18.5"



@@ -53,6 +54,9 @@ select = [
    # isort
    "I",
]
ignore = [
    "F722",  # forward-annotation-syntax-error (F722)
]


[tool.ruff.format]
30 changes: 30 additions & 0 deletions tests/test_attention.py
@@ -0,0 +1,30 @@
import unittest

import jax.numpy as jnp

from fanan.modules.attentions import self_attention


class TestAttention(unittest.TestCase):
    def setUp(self):
        self.batch_size = 8
        self.sequence_length = 32
        self.n_heads = 8
        self.head_dim = 4

    def test_self_attention(self):
        query = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32)
        value = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32)
        key = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32)

        result = self_attention(query, value, key)

        expected_shape = (self.batch_size, self.sequence_length, self.n_heads, self.head_dim)
        self.assertEqual(result.shape, expected_shape)

        # With all-ones inputs the softmax weights average all-ones values, so every entry lies in [0, 1].
        self.assertTrue(jnp.all(result >= 0))
        self.assertTrue(jnp.all(result <= 1))


if __name__ == "__main__":
    unittest.main()
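
A possible follow-up test for the grouped-query path, sketched under the assumption that key/value arrays may carry fewer heads than the query (not part of this commit):

import unittest

import jax.numpy as jnp

from fanan.modules.attentions import self_attention


class TestGroupedQueryAttention(unittest.TestCase):
    """Hypothetical extra case exercising the q_heads > kv_heads branch."""

    def test_grouped_query_attention(self):
        batch, seq, q_heads, kv_heads, head_dim = 8, 32, 8, 2, 4
        query = jnp.ones((batch, seq, q_heads, head_dim), dtype=jnp.float32)
        value = jnp.ones((batch, seq, kv_heads, head_dim), dtype=jnp.float32)
        key = jnp.ones((batch, seq, kv_heads, head_dim), dtype=jnp.float32)

        result = self_attention(query, value, key)

        # Keys/values are repeated internally, so the output keeps the query head count.
        self.assertEqual(result.shape, (batch, seq, q_heads, head_dim))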
