Commit 0e455da (1 parent: a10d106)
Showing 3 changed files with 159 additions and 0 deletions.
continuiti/networks/multi_head_attention.py
@@ -0,0 +1,90 @@
""" | ||
`continuiti.networks.multi_head_attention` | ||
Multi-Head-Attention in continuiti. | ||
""" | ||
|
||
from typing import Callable | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class MultiHeadAttention(nn.Module): | ||
r"""Multi-Head Attention module. | ||
Module as described in the paper [Attention is All you Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) | ||
with optional bias for the projections. | ||
$$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$ | ||
where | ||
$$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$ | ||
Args: | ||
hidden_dim: dimension of the hidden layers (embedding dimension). | ||
n_heads: number of attention heads. | ||
attention: implementation of attention (defaults to scaled dot product attention). Needs to have the arguments | ||
`query`, `key`, `value`, `attn_mask`, and `dropout_p`. | ||
dropout_p: dropout probability. | ||
bias: If True, then the projection onto the different heads is performed with bias. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_dim: int, | ||
n_heads: int, | ||
attention: Callable = nn.functional.scaled_dot_product_attention, | ||
dropout_p: float = 0, | ||
bias: bool = True, | ||
): | ||
super().__init__() | ||
|
||
self.hidden_dim = hidden_dim | ||
self.n_heads = n_heads | ||
self.attention = attention | ||
self.dropout_p = dropout_p | ||
self.bias = bias | ||
|
||
self.head_dim = hidden_dim // n_heads | ||
assert ( | ||
self.head_dim * n_heads == hidden_dim | ||
), "hidden_dim must be divisible by n_heads" | ||
|
||
# projection networks | ||
self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
|
||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch, | ||
attn_mask: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
batch_size = query.size(0) | ||
|
||
# project values | ||
query = self.query_project(query) | ||
key = self.key_project(key) | ||
value = self.value_project(value) | ||
|
||
# form individual heads | ||
query = query.reshape(batch_size, self.n_heads, -1, self.head_dim) | ||
key = key.reshape(batch_size, self.n_heads, -1, self.head_dim) | ||
value = value.reshape(batch_size, self.n_heads, -1, self.head_dim) | ||
|
||
# perform attention | ||
attn_out = self.attention( | ||
query=query, | ||
key=key, | ||
value=value, | ||
attn_mask=attn_mask, | ||
dropout_p=self.dropout_p, | ||
) | ||
attn_out = attn_out.reshape(batch_size, -1, self.hidden_dim) | ||
|
||
# output projection | ||
return self.out_project(attn_out) |
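For orientation, here is a minimal usage sketch of the module above; it is not part of the commit. The import path follows the test file in this commit, the batch size and sequence lengths are illustrative, and the boolean mask follows scaled_dot_product_attention's convention that True marks positions allowed to attend.

import torch

from continuiti.networks import MultiHeadAttention

attn = MultiHeadAttention(hidden_dim=32, n_heads=4, dropout_p=0.0, bias=True)

query = torch.rand(2, 8, 32)   # (batch, query tokens, hidden_dim)
key = torch.rand(2, 10, 32)    # (batch, key/value tokens, hidden_dim)
value = torch.rand(2, 10, 32)

out = attn(query, key, value)  # -> (2, 8, 32)

# optional boolean attention mask, broadcastable to (batch, n_heads, query tokens, key tokens)
mask = torch.ones(2, 1, 8, 10, dtype=torch.bool)
masked_out = attn(query, key, value, attn_mask=mask)  # same shape as out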
@@ -0,0 +1,68 @@
import pytest
import torch
import torch.nn as nn

from continuiti.networks import MultiHeadAttention


@pytest.fixture(scope="session")
def some_multi_head_attn():
    return MultiHeadAttention(
        hidden_dim=32,
        n_heads=4,
        attention=nn.functional.scaled_dot_product_attention,
        dropout_p=0.25,
        bias=True,
    )


class TestMultiHeadAttention:
    def test_can_initialize(self, some_multi_head_attn):
        assert isinstance(some_multi_head_attn, MultiHeadAttention)

    def test_output_shape(self, some_multi_head_attn):
        batch_size = 3
        hidden_dim = 32
        query_size = 5
        key_val_size = 7

        query = torch.rand(batch_size, query_size, hidden_dim)
        key = torch.rand(batch_size, key_val_size, hidden_dim)
        val = torch.rand(batch_size, key_val_size, hidden_dim)

        out = some_multi_head_attn(query, key, val)
        correct_out = nn.functional.scaled_dot_product_attention(query, key, val)

        assert out.shape == correct_out.shape

    def test_attention_correct(self):
        """Edge case testing for correctness."""
        m_attn = MultiHeadAttention(4, 4, bias=False)

        batch_size = 3
        hidden_dim = 4
        query_size = 5
        key_val_size = 7

        query = torch.rand(batch_size, query_size, hidden_dim)
        key = torch.rand(batch_size, key_val_size, hidden_dim)

        # V = 0 -> attention output == 0 (with bias=False, the value projection of zeros stays zero)
        out = m_attn(query, key, torch.zeros(batch_size, key_val_size, hidden_dim))
        assert torch.allclose(out, torch.zeros(out.shape))

    def test_gradient_flow(self, some_multi_head_attn):
        hidden_size = 32
        some_multi_head_attn.eval()  # turn off dropout and other stochastic processes
        # query, key, and value refer to the same tensor here, so the argument order is irrelevant
        query = key = value = torch.rand((10, 5, hidden_size), requires_grad=True)
        output = some_multi_head_attn(query, key, value)
        output.sum().backward()

        assert query.grad is not None, "Gradients not flowing to query"
        assert key.grad is not None, "Gradients not flowing to key"
        assert value.grad is not None, "Gradients not flowing to value"
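The tests above do not exercise the attn_mask argument. Below is a hedged sketch of an additional check one might append to the test file (not part of the commit, reusing its imports). It assumes scaled_dot_product_attention's boolean-mask convention, where True marks key positions that may be attended to: if every query may only attend to the first key/value position, the result should match attending to that single position directly, since the softmax weight over one unmasked position is 1.

def test_attn_mask_restricts_attention():
    """Sketch: masking all but the first key/value position should equal attending to it alone."""
    m_attn = MultiHeadAttention(hidden_dim=8, n_heads=2, dropout_p=0.0, bias=True)
    m_attn.eval()

    batch_size, query_size, key_val_size, hidden_dim = 2, 3, 5, 8
    query = torch.rand(batch_size, query_size, hidden_dim)
    key = torch.rand(batch_size, key_val_size, hidden_dim)
    value = torch.rand(batch_size, key_val_size, hidden_dim)

    # boolean mask broadcastable to (batch, n_heads, query_size, key_val_size);
    # only the first key/value position is left unmasked
    mask = torch.zeros(batch_size, 1, query_size, key_val_size, dtype=torch.bool)
    mask[..., 0] = True

    masked_out = m_attn(query, key, value, attn_mask=mask)
    restricted_out = m_attn(query, key[:, :1], value[:, :1])

    assert torch.allclose(masked_out, restricted_out, atol=1e-6)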