This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

sliding window self-attention cell #1395

Open · wants to merge 2 commits into base: master

157 changes: 157 additions & 0 deletions src/gluonnlp/attention_cell.py
@@ -670,6 +670,163 @@ def __repr__(self):
dtype=self._dtype)




def multi_head_sliding_window_dot_attn(F, query, key, value, dilation, valid_length,
window_size: int, symmetric: bool = True,
dropout: float = 0.0, scaled: bool = True,
normalized: bool = False, eps: float = 1E-6,
query_head_units: Optional[int] = None,
layout: str = 'NTK',
dtype=np.float32):
"""Multihead sliding window attention between the query, key and value,
described at *Longformer: The Long-Document Transformer*,
available at https://arxiv.org/pdf/2004.05150.pdf.

Given a fixed window size *2w*, each token attends to *w* tokens on the left side
if using causal attention (setting *symmetric* to *False*),
otherwise each token attends to *w* tokens on each side.

Parameters
----------
F
query
Query. The shape is (batch_size, seq_length, num_heads, num_head_units)
key
Key. The shape is (batch_size, seq_length, num_heads, num_head_units)
value
Value. The shape is (batch_size, seq_length, num_heads, num_head_units)
dilation
Dilation. The shape is (num_heads,)
valid_length
Valid length. The shape is (batch_size,)
window_size
The one-sided window length.
symmetric
If False, each token can only attend to itself and the previous tokens.
dropout
Dropout rate
scaled
Whether to divide the attention scores by the square root of the query head dimension.
normalized
If turned on, the cosine distance is used, i.e::

score = <h_q / ||h_q||, h_k / ||h_k||>

eps
The epsilon value used in L2 normalization
query_head_units
The number of units in each query head. If None, it is inferred at runtime from the
shape of the query.
layout
The layout of the input/output tensors. Currently only 'NTK' is supported, where
'N' is the batch dimension, 'T' the sequence length dimension, and 'K' the head dimension.

Returns
-------
context_vec
- Shape (batch_size, seq_length, num_heads * num_head_units)
additional_info
scores:
Shape (batch_size, seq_length, num_heads, w + w + 1) if *symmetric* is True,
Shape (batch_size, seq_length, num_heads, w + 1) otherwise
attn_weight:
Shape (batch_size, seq_length, num_heads, w + w + 1) if *symmetric* is True,
Shape (batch_size, seq_length, num_heads, w + 1) otherwise
"""
if layout != "NTK":
raise NotImplementedError('We only support layout = "NTK".')
if normalized:
query = l2_normalize(F, query, axis=-1, eps=eps)
key = l2_normalize(F, key, axis=-1, eps=eps)
# 1. Calculate the attention weights
# scores' shape (batch_size, seq_length, num_heads, w + w + 1) if symmetric else
# (batch_size, seq_length, num_heads, w + 1)
scores = F.npx.sldwin_atten_score(query, key, dilation,
w=window_size, symmetric=symmetric)
if scaled:
if query_head_units is None:
query_shape = F.npx.shape_array(query)
scores = scores / F.np.sqrt(query_shape[-1])
else:
scores = scores / math.sqrt(query_head_units)
# mask's shape is the same as scores
mask = F.npx.sldwin_atten_mask_like(scores, dilation, valid_length.astype(np.int32),
w=window_size, symmetric=symmetric)
attn_weights = masked_softmax(F, scores, mask, dtype=dtype)
attn_weights = F.npx.dropout(attn_weights, p=dropout)
# 2. Calculate the context vector
# (batch_size, seq_length, num_heads, num_head_units)
context_vec = F.npx.sldwin_atten_context(attn_weights, value, dilation,
w=window_size, symmetric=symmetric)
# (batch_size, seq_length, num_units)
context_vec = F.npx.reshape(context_vec, (-2, -2, -1))

return context_vec, [scores, attn_weights]


class MultiHeadSlidingWindowAttentionCell(HybridBlock):
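"""Multi-head sliding window attention cell (Longformer-style local attention).

A thin HybridBlock wrapper around `multi_head_sliding_window_dot_attn`: each head
attends within an (optionally dilated) window of `window_size` tokens on each side,
or only to itself and the tokens on its left when `symmetric=False`.
"""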
def __init__(self, window_size, symmetric=True, query_units=None, num_heads=None,
attention_dropout=0.0, scaled: bool = True, normalized: bool = False,
eps: float = 1E-6, dtype='float32', layout='NTK'):
super().__init__()
self._query_units = query_units
self._window_size = window_size
self._symmetric = symmetric
self._num_heads = num_heads
self._attention_dropout = attention_dropout
self._scaled = scaled
self._normalized = normalized
self._eps = eps
self._dtype = dtype
self._layout = layout
if self._query_units is not None:
assert self._num_heads is not None
assert self._query_units % self._num_heads == 0,\
'The units must be divisible by the number of heads.'
self._query_head_units = self._query_units // self._num_heads
else:
self._query_head_units = None

@property
def layout(self):
return self._layout

def hybrid_forward(self, F, query, key, value, dilation, valid_length):
return multi_head_sliding_window_dot_attn(F, query=query, key=key,
value=value, dilation=dilation,
valid_length=valid_length, window_size=self._window_size,
symmetric=self._symmetric, dropout=self._attention_dropout,
scaled=self._scaled, normalized=self._normalized, eps=self._eps,
query_head_units=self._query_head_units, layout=self._layout,
dtype=self._dtype)

def __repr__(self):
s = '{name}(\n' \
' window_size={window_size},\n' \
' symmetric={symmetric},\n' \
' query_units={query_units},\n' \
' num_heads={num_heads},\n' \
' attention_dropout={attention_dropout},\n' \
' scaled={scaled},\n' \
' normalized={normalized},\n' \
' layout="{layout}",\n' \
' dtype={dtype}\n' \
')'
return s.format(name=self.__class__.__name__,
window_size=self._window_size,
symmetric=self._symmetric,
query_units=self._query_units,
num_heads=self._num_heads,
attention_dropout=self._attention_dropout,
scaled=self._scaled,
normalized=self._normalized,
layout=self._layout,
dtype=self._dtype)


class RelAttentionScoreCell(HybridBlock):
"""Get the score based on the query and relative position index. This is used for implementing
relative attention.
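For reference, the following is a minimal usage sketch of the new cell, assuming an MXNet build that provides the npx.sldwin_atten_* operators this PR relies on; the sizes and variable names are illustrative only, and the shapes follow the 'NTK' layout documented above.

import mxnet as mx
import numpy as np
from gluonnlp.attention_cell import MultiHeadSlidingWindowAttentionCell

mx.npx.set_np()

# Toy sizes (illustrative only); layout 'NTK' means (batch, seq, heads, head_units).
batch_size, seq_length, num_heads, num_head_units = 2, 64, 4, 16
window_size = 8  # one-sided window, so each token sees at most 2 * 8 + 1 positions

cell = MultiHeadSlidingWindowAttentionCell(window_size=window_size, symmetric=True)

query = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
key = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
value = mx.np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
dilation = mx.np.ones((num_heads,), dtype=np.int32)                   # no dilation
valid_length = mx.np.full((batch_size,), seq_length, dtype=np.int32)  # no padding

context_vec, additional_info = cell(query, key, value, dilation, valid_length)
scores, attn_weights = additional_info
# context_vec:          (batch_size, seq_length, num_heads * num_head_units)
# scores, attn_weights: (batch_size, seq_length, num_heads, 2 * window_size + 1)

Setting entries of `dilation` above 1 widens the receptive field of the corresponding head without increasing the number of attended positions per token.
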
77 changes: 76 additions & 1 deletion tests/test_attention_cell.py
@@ -6,7 +6,8 @@
from gluonnlp.attention_cell import\
multi_head_dot_attn, gen_self_attn_mask, gen_mem_attn_mask,\
MultiHeadAttentionCell,\
RelAttentionScoreCell
RelAttentionScoreCell,\
MultiHeadSlidingWindowAttentionCell
from gluonnlp.utils.parameter import grad_global_norm
mx.npx.set_np()

@@ -388,3 +389,77 @@ def test_multi_head_rel_attn_score(num_heads, method, bidirectional, hybridize,
assert_allclose(rel_score.asnumpy(), original_rel_score, 1E-5, 1E-5)
layout_query_grad_norm = np.linalg.norm(query.grad.asnumpy())
assert_allclose(layout_query_grad_norm, original_query_grad_norm, 1E-5, 1E-5)



def test_multi_head_sliding_window_dot_attention_cell():

def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
"""Generate sliding_window attention mask for the full attention matrix ( seq_len^2 ).
"""
mask_np = np.zeros((batch_size, seq_length, seq_length))
for i in range(seq_length):
end = (i + 1 + w * d) if symmetric else (i + 1)
for j in range(i - w * d, end, d):
if j >= 0 and j < seq_length:
mask_np[:, i, j] = 1
return mask_np

def test_impl(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
attn_cell = MultiHeadAttentionCell()
sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
# Generate the data
query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
mask = gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d)
mask = mx.np.array(mask, dtype=np.float32)

query = mx.np.array(query, dtype=np.float32)
key = mx.np.array(key, dtype=np.float32)
value = mx.np.array(value, dtype=np.float32)

query.attach_grad()
key.attach_grad()
value.attach_grad()

with mx.autograd.record():
out, _ = attn_cell(query, key, value, mask)
out.backward()

out_np = out.asnumpy()
grad_query = query.grad.asnumpy()
grad_key = key.grad.asnumpy()
grad_value = value.grad.asnumpy()

query.grad[:] = 0
key.grad[:] = 0
value.grad[:] = 0

# Every head uses the same dilation and every sample is fully valid.
dilation = mx.np.full((num_heads,), d, dtype=np.int32)
valid_length = mx.np.full((batch_size,), seq_length, dtype=np.int32)

with mx.autograd.record():
sw_out, _ = sw_attn_cell(query, key, value, dilation, valid_length)
sw_out.backward()

sw_out_np = sw_out.asnumpy()
sw_grad_query = query.grad.asnumpy()
sw_grad_key = key.grad.asnumpy()
sw_grad_value = value.grad.asnumpy()

assert_allclose(sw_out_np, out_np, 1E-3, 1E-3)
assert_allclose(sw_grad_key, grad_key, 1E-3, 1E-3)
assert_allclose(sw_grad_value, grad_value, 1E-3, 1E-3)
assert_allclose(sw_grad_query, grad_query, 1E-3, 1E-3)

for symmetric in [True, False]:
for d in [1, 2, 3]:
test_impl(4, 128, 12, 64, 16, symmetric, d)
test_impl(1, 8, 2, 3, 2, symmetric, d)
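
To make the masking pattern concrete, here is a small standalone sketch (pure NumPy, independent of MXNet) that prints the reference mask used above for a toy configuration; `build_mask` is a hypothetical single-sample re-statement of `gen_sliding_window_mask_full`, written only for illustration.

import numpy as np

def build_mask(seq_length, w, symmetric, d):
    # 1 marks positions token i may attend to; mirrors gen_sliding_window_mask_full.
    mask = np.zeros((seq_length, seq_length), dtype=np.int64)
    for i in range(seq_length):
        end = i + 1 + w * d if symmetric else i + 1
        for j in range(i - w * d, end, d):
            if 0 <= j < seq_length:
                mask[i, j] = 1
    return mask

print(build_mask(seq_length=6, w=1, symmetric=True, d=1))
# [[1 1 0 0 0 0]
#  [1 1 1 0 0 0]
#  [0 1 1 1 0 0]
#  [0 0 1 1 1 0]
#  [0 0 0 1 1 1]
#  [0 0 0 0 1 1]]

With symmetric=False the entries to the right of the diagonal disappear, leaving only each token itself and the w (dilated) positions to its left.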