
Commit

Add MultiHeadAttention.
JakobEliasWagner authored and Samuel Burbulla committed Apr 23, 2024
1 parent a10d106 commit 0e455da
Showing 3 changed files with 159 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@
- Add `benchmarks` infrastructure.
- An `Operator` now takes a `device` argument.
- Add `QuantileScaler` class.
- Add `MultiHeadAttention` class.

## 0.0.0 (2024-02-22)

90 changes: 90 additions & 0 deletions src/continuiti/networks/multi_head_attention.py
@@ -0,0 +1,90 @@
"""
`continuiti.networks.multi_head_attention`
Multi-head attention in continuiti.
"""

from typing import Callable, Optional
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
r"""Multi-Head Attention module.
    Implements multi-head attention as described in [Attention Is All You Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf),
    with an optional bias for the projections.
$$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$
where
$$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$
Args:
hidden_dim: dimension of the hidden layers (embedding dimension).
n_heads: number of attention heads.
attention: implementation of attention (defaults to scaled dot product attention). Needs to have the arguments
`query`, `key`, `value`, `attn_mask`, and `dropout_p`.
dropout_p: dropout probability.
        bias: If True, the query, key, value, and output projections include a bias term.
"""

def __init__(
self,
hidden_dim: int,
n_heads: int,
attention: Callable = nn.functional.scaled_dot_product_attention,
dropout_p: float = 0,
bias: bool = True,
):
super().__init__()

self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.attention = attention
self.dropout_p = dropout_p
self.bias = bias

self.head_dim = hidden_dim // n_heads
assert (
self.head_dim * n_heads == hidden_dim
), "hidden_dim must be divisible by n_heads"

# projection networks
self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = query.size(0)

        # project query, key, and value
        query = self.query_project(query)
        key = self.key_project(key)
        value = self.value_project(value)

        # split heads: (batch, seq, hidden_dim) -> (batch, n_heads, seq, head_dim)
        query = query.reshape(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

# perform attention
attn_out = self.attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,  # no dropout in eval mode
)
        # merge heads: (batch, n_heads, seq, head_dim) -> (batch, seq, hidden_dim)
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)

# output projection
return self.out_project(attn_out)
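
As a usage sketch (not part of the commit; shapes and parameter values are illustrative), the module can be imported from continuiti.networks and applied to batched sequences whose last dimension equals hidden_dim:

import torch

from continuiti.networks import MultiHeadAttention

# Illustrative configuration: 4 heads over a hidden dimension of 32.
attn = MultiHeadAttention(hidden_dim=32, n_heads=4, dropout_p=0.1)

# Batch of 3, query length 5, key/value length 7, feature dimension 32.
query = torch.rand(3, 5, 32)
key = torch.rand(3, 7, 32)
value = torch.rand(3, 7, 32)

out = attn(query, key, value)
print(out.shape)  # torch.Size([3, 5, 32])
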
68 changes: 68 additions & 0 deletions tests/networks/test_multi_head_attention.py
@@ -0,0 +1,68 @@
import pytest
import torch
import torch.nn as nn

from continuiti.networks import MultiHeadAttention


@pytest.fixture(scope="session")
def some_multi_head_attn():
return MultiHeadAttention(
hidden_dim=32,
n_heads=4,
attention=nn.functional.scaled_dot_product_attention,
dropout_p=0.25,
bias=True,
)


class TestMultiHeadAttention:
def test_can_initialize(self, some_multi_head_attn):
assert isinstance(some_multi_head_attn, MultiHeadAttention)

def test_output_shape(self, some_multi_head_attn):
batch_size = 3
hidden_dim = 32
query_size = 5
key_val_size = 7

query = torch.rand(batch_size, query_size, hidden_dim)
key = torch.rand(batch_size, key_val_size, hidden_dim)
val = torch.rand(batch_size, key_val_size, hidden_dim)

out = some_multi_head_attn(query, key, val)
correct_out = nn.functional.scaled_dot_product_attention(query, key, val)

assert out.shape == correct_out.shape

def test_attention_correct(self):
"""Edge case testing for correctness."""
m_attn = MultiHeadAttention(4, 4, bias=False)

batch_size = 3
hidden_dim = 4
query_size = 5
key_val_size = 7

query = torch.rand(batch_size, query_size, hidden_dim)
key = torch.rand(batch_size, key_val_size, hidden_dim)

        # V = 0 -> attention output is zero
out = m_attn(query, key, torch.zeros(batch_size, key_val_size, hidden_dim))
assert torch.allclose(out, torch.zeros(out.shape))

def test_gradient_flow(self, some_multi_head_attn):
hidden_size = 32
some_multi_head_attn.eval() # Turn off dropout or other stochastic processes
query = key = value = torch.rand((10, 5, hidden_size), requires_grad=True)
        output = some_multi_head_attn(query, key, value)
output.sum().backward()

assert query.grad is not None, "Gradients not flowing to query"
assert key.grad is not None, "Gradients not flowing to key"
assert value.grad is not None, "Gradients not flowing to value"
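
A further check one could add (a sketch, not part of this commit; the test name and tolerance are illustrative): with n_heads=1, splitting and merging heads is a no-op, so the module should match scaled dot-product attention wrapped in its linear projections.

def test_single_head_reduces_to_sdpa():
    # Sketch: with a single head, MultiHeadAttention should equal
    # scaled_dot_product_attention applied between its projections
    # (dropout_p defaults to 0, so the comparison is deterministic).
    m_attn = MultiHeadAttention(hidden_dim=8, n_heads=1, bias=False)

    query = torch.rand(2, 5, 8)
    key = torch.rand(2, 7, 8)
    value = torch.rand(2, 7, 8)

    expected = m_attn.out_project(
        nn.functional.scaled_dot_product_attention(
            m_attn.query_project(query),
            m_attn.key_project(key),
            m_attn.value_project(value),
        )
    )
    assert torch.allclose(m_attn(query, key, value), expected, atol=1e-6)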
