diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..648d493 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,28 @@ +name: Unittests + +on: [push, pull_request] + +jobs: + test: + name: Run Unittests + runs-on: ubuntu-latest + + steps: + - name: Checkout Repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.x + + - name: Install Poetry + run: | + curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python + echo "$HOME/.poetry/bin" >> "$GITHUB_PATH" + + - name: Install Dependencies + run: poetry install + + - name: Run Unittests + run: poetry run python -m unittest discover -s tests -v diff --git a/fanan/modules/attentions/__init__.py b/fanan/modules/attentions/__init__.py new file mode 100644 index 0000000..4838032 --- /dev/null +++ b/fanan/modules/attentions/__init__.py @@ -0,0 +1,16 @@ +from typing import Any, Dict + +# Registry mapping lowercased attention-fn names to their implementations. +_ATTENTIONS: Dict[str, Any] = {} + + +def register_attention_fn(fn): + # Keys are stored lowercased; lookups must lowercase too. + _ATTENTIONS[fn.__name__.lower()] = fn + return fn + + +from fanan.modules.attentions.self_attention import * # noqa: E402, F403 + + +def get_attention_fn(name: str): + assert name.lower() in _ATTENTIONS, f"attention fn {name=} is not supported. 
Available attentions: {_ATTENTIONS.keys()}" + return _ATTENTIONS[name.lower()] diff --git a/fanan/modules/attentions/self_attention.py b/fanan/modules/attentions/self_attention.py new file mode 100644 index 0000000..270fa2c --- /dev/null +++ b/fanan/modules/attentions/self_attention.py @@ -0,0 +1,39 @@ +import math + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Array, Float, jaxtyped + +from fanan.modules.attentions import register_attention_fn + + +@register_attention_fn +@jaxtyped(typechecker=typechecker) +def self_attention( + query: Float[Array, "batch_size sequence_length n_heads head_dim"], + value: Float[Array, "batch_size sequence_length n_heads head_dim"], + key: Float[Array, "batch_size sequence_length n_heads head_dim"], + mask: jax.Array | None = None, +) -> Float[Array, "batch_size sequence_length n_heads head_dim"]: + """Self attention mechanism (scaled dot-product; repeats kv heads for GQA).""" + kv_heads = key.shape[-2] + q_heads, head_dim = query.shape[-2], query.shape[-1] + + if q_heads != kv_heads: + assert q_heads > kv_heads + tile_factor = q_heads // kv_heads + key = jnp.repeat(key, tile_factor, axis=-2) + value = jnp.repeat(value, tile_factor, axis=-2) + + scale = float(1 / math.sqrt(head_dim)) + + attention_logits = jnp.einsum("bthd,bThd->bhtT", query, key) + attention_logits = (attention_logits * scale).astype(query.dtype) + + attention_weights = jax.nn.softmax(jnp.where(mask, attention_logits, jnp.finfo(attention_logits.dtype).min) if mask is not None else attention_logits, axis=-1)  # NOTE(review): assumes mask broadcasts to (b, h, t, T) with True = attend — confirm caller convention + attention_weights = attention_weights.astype(value.dtype) + + attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, value) + + return attention_vec diff --git a/poetry.lock b/poetry.lock index 8ee4b13..0d34cb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -182,6 +182,24 @@ tests = ["attrs[tests-no-zope]", "zope-interface"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +[[package]] +name = "beartype" +version = 
"0.18.5" +description = "Unbearably fast runtime type checking in pure Python." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"}, + {file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"}, +] + +[package.extras] +all = ["typing-extensions (>=3.10.0.0)"] +dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"] +doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"] +test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"] +test-tox-coverage = ["coverage (>=5.5)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -3523,4 +3541,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "3.11.9" -content-hash = "6d9d089afbe4bf1b4e2c7939e22d9c5306747dee6de438bc61b885fb5162cb89" +content-hash = "13e9009e8a56b090641240e09a5d3c2e768e9c2d7ce74797976c9aa7b7d4796a" diff --git a/pyproject.toml b/pyproject.toml index db00032..a044bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ jmp = "0.0.4" jaxtyping = "0.2.28" pre-commit = "3.7.0" ipdb = "0.13.13" +beartype = "0.18.5" @@ -53,6 +54,9 @@ select = [ # isort "I", ] +ignore = [ + "F722" # forward-annotation-syntax-error (F722) +] [tool.ruff.format] diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000..a775f33 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,30 @@ +import unittest + +import jax.numpy as jnp + +from fanan.modules.attentions 
import self_attention + + +class TestAttention(unittest.TestCase): + def setUp(self): + self.batch_size = 8 + self.sequence_length = 32 + self.n_heads = 8 + self.head_dim = 4 + + def test_self_attention(self): + query = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32) + value = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32) + key = jnp.ones((self.batch_size, self.sequence_length, self.n_heads, self.head_dim), dtype=jnp.float32) + + result = self_attention(query, value, key) + + expected_shape = (self.batch_size, self.sequence_length, self.n_heads, self.head_dim) + self.assertEqual(result.shape, expected_shape) + + # Uniform logits -> uniform softmax weights summing to 1 over all-ones values -> exactly ones. + self.assertTrue(jnp.allclose(result, jnp.ones_like(result))) + self.assertEqual(result.dtype, jnp.float32) + + +if __name__ == "__main__": + unittest.main()