split outside attention function (#106)
* split outside attention function

* Update typing.py

* Update conf.py

* Update readthedocs.yaml

* Update docs.yml

* Update train_transformer.ipynb
ASEM000 authored Apr 9, 2024
1 parent 13e6a95 commit a129c7b
Showing 7 changed files with 32 additions and 41 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/docs.yml
@@ -30,14 +30,11 @@ jobs:
           pip install .
           pip install -r docs/requirements.txt
-      - name: Setup Graphviz
-        uses: ts-graphviz/setup-graphviz@v1
-
-      - name: Test doctests
+      - name: Run doctests
         run: |
           cd docs
           make doctest
-      - name: Test docs to HTML
+      - name: Make HTML
         run: |
           cd docs
           make html
1 change: 1 addition & 0 deletions docs/conf.py
@@ -105,6 +105,7 @@
         "notebook_interface": "jupyterlab",
         "colab_url": "https://colab.research.google.com/",
     },
+    "navigation_with_keys": False,
 }
 
 html_css_files = ["custom.css"]
12 changes: 8 additions & 4 deletions docs/notebooks/train_transformer.ipynb
@@ -427,7 +427,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Load the dataset"
+    "### Load the dataset"
    ]
   },
   {
@@ -448,7 +448,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Tokenizer"
+    "### Tokenizer"
    ]
   },
   {
@@ -489,7 +489,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Dataloader"
+    "### Dataloader"
    ]
   },
   {
@@ -946,7 +946,11 @@
     "    text_en = test_dataset[index][\"translation\"][\"en\"]\n",
     "    text_en_pred = translate_from_arabic_to_english(text_ar, key=jr.PRNGKey(0))\n",
     "\n",
-    "    print(f\"input arabic: {text_ar}\\n\" f\"true english: {text_en}\\n\" f\"pred english: {text_en_pred}\\n\")"
+    "    print(\n",
+    "        f\"input arabic: {text_ar}\\n\"\n",
+    "        f\"true english: {text_en}\\n\"\n",
+    "        f\"pred english: {text_en_pred}\\n\"\n",
+    "    )"
    ]
   }
 ],
2 changes: 1 addition & 1 deletion readthedocs.yaml
@@ -6,7 +6,7 @@ version: 2
 build:
   os: ubuntu-22.04
   tools:
-    python: "3.8"
+    python: "3.10"
 
 sphinx:
   builder: html
45 changes: 17 additions & 28 deletions serket/_src/nn/attention.py
@@ -57,45 +57,35 @@ def is_lazy_init(_, num_heads, q_features, *__, **___) -> bool:
 
 
 def dot_product_attention(
-    q_heads: jax.Array,
-    k_heads: jax.Array,
-    v_heads: jax.Array,
-    num_heads: int,
-    mask: jax.Array | None,
+    q_heads: Annotated[jax.Array, "..., q_length, num_heads, head_features"],
+    k_heads: Annotated[jax.Array, "..., kv_length, num_heads, head_features"],
+    v_heads: Annotated[jax.Array, "..., kv_length, num_heads, head_features"],
+    mask: Annotated[jax.Array, "..., num_heads, q_length, kv_length"] | None,
     drop_func: Callable[[jax.Array], jax.Array],
 ) -> jax.Array:
     """Applies multi-head attention to the given inputs.
 
     Args:
-        q_input: Query input. [..., q_length, q_features]
-        k_input: Key input. [..., k_length, k_features]
-        v_input: Value input. [..., v_length, v_features]
+        q_input: Query input. [..., q_length, num_heads, head_features]
+        k_input: Key input. [..., k_length, num_heads, head_features]
+        v_input: Value input. [..., v_length, num_heads, head_features]
         mask: Mask input. [..., num_heads, q_length, kv_length]. Use ``None``
             for no masking.
-        num_heads: Number of attention heads.
         drop_func: Dropout function. Takes a single input and returns a single output.
             Use ``lambda input: input`` for no dropout.
 
     Reference:
         - https://keras.io/api/layers/attention_layers/multi_head_attention/
        - https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html
     """
-    k_depth = k_heads.shape[-1]
-    # [..., q_length, head_features*num_heads] -> [..., q_length, num_heads, head_features]
-    q_heads = split_heads(q_heads, num_heads)
-    # [..., k_length, head_features*num_heads] -> [..., k_length, num_heads, head_features]
-    k_heads = split_heads(k_heads, num_heads)
-    # [..., v_length, head_features*num_heads] -> [..., v_length, num_heads, head_features]
-    v_heads = split_heads(v_heads, num_heads)
-
+    *_, num_heads, k_depth = k_heads.shape
     logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads)
     logits /= jnp.sqrt(k_depth // num_heads)
 
     min_num = jnp.finfo(logits.dtype).min
     logits = logits if mask is None else jnp.where(mask, logits, min_num)
 
-    attention_weight = jax.nn.softmax(logits)
-    attention = jnp.einsum("...hqk,...khd->...qhd", attention_weight, v_heads)
+    weight = jax.nn.softmax(logits)
+    attention = jnp.einsum("...hqk,...khd->...qhd", weight, v_heads)
     # avoid using Dropout layers inside functions
     return merge_heads(drop_func(attention))
 
@@ -309,18 +299,17 @@ def __call__(
             Defaults to ``None`` for no dropout.
         """
 
-        # [..., q_length, q_features] -> [..., q_length, head_features*num_heads]
-        q_heads = self.q_projection(q_input)
-        # [..., k_length, k_features] -> [..., k_length, head_features*num_heads]
-        k_heads = self.k_projection(k_input)
-        # [..., v_length, v_features] -> [..., v_length, head_features*num_heads]
-        v_heads = self.v_projection(v_input)
+        # [..., q_length, q_features] -> [..., q_length, head_features, num_heads]
+        q_heads = split_heads(self.q_projection(q_input), self.num_heads)
+        # [..., k_length, k_features] -> [..., k_length, head_features, num_heads]
+        k_heads = split_heads(self.k_projection(k_input), self.num_heads)
+        # [..., v_length, v_features] -> [..., v_length, head_features, num_heads]
+        v_heads = split_heads(self.v_projection(v_input), self.num_heads)
 
         attention = self.attention_op(
             q_heads=q_heads,
             k_heads=k_heads,
             v_heads=v_heads,
-            num_heads=self.num_heads,
             mask=mask,
             # note that if `tree_eval` is used, self.dropout is converted to an
             # identity function, so the `key` argument is ignored.
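Taken together, the two hunks move head splitting out of `dot_product_attention` and into its caller: the layer now projects, splits, and only then hands per-head arrays to the attention function, which recovers `num_heads` from the key's shape instead of taking it as an argument. A minimal self-contained sketch of the resulting data flow; the `split_heads`/`merge_heads` stand-ins and the standard 1/sqrt(head_features) scaling below are assumptions for illustration, since serket's own helpers are not shown in this diff:

import jax
import jax.numpy as jnp

def split_heads(x, num_heads):
    # [..., length, num_heads * head_features] -> [..., length, num_heads, head_features]
    return x.reshape(*x.shape[:-1], num_heads, x.shape[-1] // num_heads)

def merge_heads(x):
    # [..., length, num_heads, head_features] -> [..., length, num_heads * head_features]
    return x.reshape(*x.shape[:-2], x.shape[-2] * x.shape[-1])

# the caller owns the split step, as in the new API
num_heads, q_length, kv_length, features = 4, 8, 8, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_heads = split_heads(jax.random.normal(kq, (q_length, features)), num_heads)
k_heads = split_heads(jax.random.normal(kk, (kv_length, features)), num_heads)
v_heads = split_heads(jax.random.normal(kv, (kv_length, features)), num_heads)

# num_heads and head depth are read off the already-split key shape
*_, n_heads, k_depth = k_heads.shape
logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads) / jnp.sqrt(k_depth)
weight = jax.nn.softmax(logits)
out = merge_heads(jnp.einsum("...hqk,...khd->...qhd", weight, v_heads))
print(out.shape)  # (8, 32)

One payoff of this split is that `dot_product_attention` becomes a pure array-in/array-out function with no layer state, which is easier to reuse and to test in isolation.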
2 changes: 1 addition & 1 deletion serket/_src/nn/linear.py
@@ -65,7 +65,7 @@ def generate_einsum_pattern(
 
 def linear(
     input: jax.Array,
-    weight: Any,
+    weight: jax.Array,
     bias: jax.Array | None,
     in_axis: Sequence[int] = (-1,),
     out_axis: Sequence[int] = (-1,),
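For context, `linear` applies a weight over arbitrary input/output axes using an einsum pattern built by `generate_einsum_pattern`; this hunk only tightens the `weight` annotation from `Any` to `jax.Array`. A minimal sketch of the default single-axis case, where the `(in_features, out_features)` weight layout is an assumption and the pattern generator itself is not shown in this diff:

from __future__ import annotations

import jax
import jax.numpy as jnp

def linear_last_axis(input: jax.Array, weight: jax.Array, bias: jax.Array | None) -> jax.Array:
    # special case of `linear` for in_axis=(-1,) and out_axis=(-1,):
    # contract the last input axis against the first weight axis
    output = jnp.einsum("...i,io->...o", input, weight)
    return output if bias is None else output + bias

x = jnp.ones((2, 3))
print(linear_last_axis(x, jnp.ones((3, 5)), jnp.zeros((5,))).shape)  # (2, 5)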
4 changes: 2 additions & 2 deletions serket/_src/utils/typing.py
@@ -14,11 +14,11 @@
 
 from __future__ import annotations
 
-from typing import Annotated, Any, Callable, Literal, Sequence, Tuple, TypeVar, Union
+from typing import Any, Callable, Literal, Sequence, Tuple, TypeVar, Union
 
 import jax
 import numpy as np
-from typing_extensions import ParamSpec
+from typing_extensions import Annotated, ParamSpec
 
 KernelSizeType = Union[int, Sequence[int]]
 StridesType = Union[int, Sequence[int]]
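`typing.Annotated` only exists on Python 3.9+, while `typing_extensions.Annotated` is a backport that also works on 3.8, so re-importing it from `typing_extensions` plausibly keeps older interpreters working (the motivation is an inference; the commit message does not state it). Usage is unchanged, e.g. for the shape annotations introduced in attention.py:

from typing_extensions import Annotated  # same object as typing.Annotated on 3.9+

import jax

# type checkers still see jax.Array; the string is documentation-only
# metadata recording the expected shape
QHeadsType = Annotated[jax.Array, "..., q_length, num_heads, head_features"]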
