WIP: First implementation of sliding window local sparse attention. #951

Open · wants to merge 13 commits into base: master
90 changes: 80 additions & 10 deletions opennmt/layers/transformer.py
@@ -117,7 +117,7 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def split_chunks(a, chunk_length, concat_3_chunks=True):
def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0):
"""Splits a tensor into chunks along the timesteps axis.

Args:
@@ -129,6 +129,10 @@ def split_chunks(a, chunk_length, concat_3_chunks=True):
A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number.
"""

if global_length:
global_a = a[:, :, :global_length, :]
a = a[:, :, global_length:, :]

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a factor of chunk_length.
@@ -157,9 +161,17 @@ def split_chunks(a, chunk_length, concat_3_chunks=True):
# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])

if global_length:
# batch, num_chunks, num_heads, global timesteps, units_per_head
expanded_global_a = tf.tile(
tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1]
)
a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3)

input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3), units_per_head
# batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
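
For intuition, here is a minimal sketch of the shapes produced by the updated `split_chunks` (tensor sizes are illustrative, and it assumes this branch is installed since upstream `split_chunks` has no `global_length` argument):

```python
import tensorflow as tf
from opennmt.layers import transformer

# 2 sequences, 4 heads, 3 global timesteps followed by 10 local timesteps, depth 8.
x = tf.random.uniform([2, 4, 13, 8])

# The 10 local timesteps are split into chunks of length 5, each chunk is
# concatenated with its left and right neighbors (chunk_length * 3), and the
# 3 global timesteps are tiled and appended to every chunk.
chunks, num_chunks = transformer.split_chunks(
    x, chunk_length=5, concat_3_chunks=True, global_length=3
)
print(num_chunks)    # 2
print(chunks.shape)  # (4, 4, 18, 8) = (batch * num_chunks, heads, 5 * 3 + 3, depth)
```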


@@ -194,7 +206,7 @@ def combine_chunks(a, num_chunks, unchunked_length):
return a[:, :, :unchunked_length, :]


def chunk_att_mask(mask, chunk_length):
def chunk_att_mask(mask, chunk_length, global_length=0):
"""Transforms an attention mask into a chunked representation.

Chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``.
@@ -207,6 +219,10 @@ def chunk_att_mask(mask, chunk_length):
A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks.
"""

if global_length:
global_mask = mask[:, :global_length]
mask = mask[:, global_length:]

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
@@ -254,9 +270,17 @@ def chunk_att_mask(mask, chunk_length):
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

if global_length:
# batch, num_chunks, chunk_length, global_length
expanded_global_mask = tf.tile(
global_mask[:, tf.newaxis, tf.newaxis, :], [1, chunk_num, chunk_length, 1]
)
mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3]
mask_unskewed,
shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length],
)
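
Similarly, a quick sketch of the chunked mask shape with a non-zero `global_length` (illustrative lengths, same branch assumption):

```python
import tensorflow as tf
from opennmt.layers import transformer

# Sequence lengths include 2 leading global tokens.
mask = tf.sequence_mask([9, 7], maxlen=9, dtype=tf.float32)  # shape [2, 9]

chunked = transformer.chunk_att_mask(mask, chunk_length=3, global_length=2)
# The local part is 9 - 2 = 7 timesteps, i.e. ceil(7 / 3) = 3 chunks per sequence.
print(chunked.shape)  # (6, 3, 11) = (batch * num_chunks, chunk_length, chunk_length * 3 + global_length)
```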


@@ -359,6 +383,7 @@ def __init__(
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=None,
global_attention_length=0,
**kwargs
):
"""Initializes this layer.
@@ -374,6 +399,7 @@ def __init__(
If not ``None``, use sparse attention for longer sequences
(from https://arxiv.org/abs/2004.08483).
local_attention_radius: Attention radius around each token for local sliding attention.
global_attention_length: Number of leading tokens that attend to the full sequence when sparse attention is used.
kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
@@ -393,6 +419,7 @@ def __init__(
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius
self.global_attention_length = global_attention_length
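
For reference, constructing the layer with the new option could look like the sketch below; the hyper-parameter values are illustrative only, and the sparse path is only taken when the input is longer than `max_length_full_attention`:

```python
from opennmt.layers import transformer

attention = transformer.MultiHeadAttention(
    8,                              # number of heads
    512,                            # depth of the layer
    max_length_full_attention=512,  # switch to sparse attention beyond this length
    local_attention_radius=128,     # sliding-window chunk length
    global_attention_length=16,     # leading tokens using global attention
)
```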

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -513,17 +540,33 @@ def _compute_kv(x):
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
if self.return_attention:
raise ValueError(
"Cannot return attention weights when using sparse attention."
)

use_sparse_att = queries_length > self.max_length_full_attention

chunk_length = self.local_attention_radius
# Dot product attention.
if use_sparse_att:
if self.global_attention_length:
global_queries = queries[:, :, : self.global_attention_length, :]
queries = queries[:, :, self.global_attention_length :, :]
global_keys = keys
global_values = values
global_dot = tf.matmul(global_queries, global_keys, transpose_b=True)

# batch x num_chunks, num_heads, chunk_length, units_per_head
queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
keys, _ = split_chunks(keys, chunk_length)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
values, num_chunks = split_chunks(values, chunk_length)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
keys, _ = split_chunks(
keys, chunk_length, global_length=self.global_attention_length
)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
values, num_chunks = split_chunks(
values, chunk_length, global_length=self.global_attention_length
)
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
@@ -532,19 +575,39 @@ def _compute_kv(x):
if mask is not None:
mask = tf.cast(mask, tf.float32)
if use_sparse_att:
mask = chunk_att_mask(mask, chunk_length)
global_mask = mask[:, tf.newaxis, tf.newaxis, :]
mask = chunk_att_mask(mask, chunk_length, self.global_attention_length)
elif mask.shape.rank == 2:
mask = tf.expand_dims(mask, 1) # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = tf.cast(
tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min),
dot.dtype,
)
if self.global_attention_length:
global_dot = tf.cast(
tf.cast(global_dot, tf.float32) * global_mask
+ ((1.0 - global_mask) * tf.float32.min),
global_dot.dtype,
)

attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
drop_attn = common.dropout(attn, self.dropout, training=training)
heads = tf.matmul(drop_attn, values)

if self.global_attention_length:
global_attn = tf.cast(
tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype
)
global_drop_attn = common.dropout(
global_attn, self.dropout, training=training
)
global_heads = tf.matmul(global_drop_attn, global_values)

if use_sparse_att:
heads = combine_chunks(heads, num_chunks, queries_length)
heads = combine_chunks(
heads, num_chunks, queries_length - self.global_attention_length
)
if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
@@ -553,6 +616,13 @@ def _compute_kv(x):
# Concatenate all heads output.
combined = combine_heads(heads)
outputs = self.linear_output(combined)
if self.global_attention_length:
global_combined = combine_heads(global_heads)
global_outputs = self.linear_output(
global_combined
)  # TODO: use separate linear input and output layers for the global tokens?
outputs = tf.concat((global_outputs, outputs), axis=1)

if self.return_attention:
return outputs, cache, attn
return outputs, cache
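
Putting it together, the sparse path with global tokens can be exercised as in the updated unit test in `opennmt/tests/transformer_test.py`; a minimal sketch (assuming the branch behaves as intended, and noting that `return_attention=True` is rejected on this path):

```python
import tensorflow as tf
from opennmt.layers import transformer

attention = transformer.MultiHeadAttention(
    4, 20,
    local_attention_radius=2,
    max_length_full_attention=3,
    global_attention_length=2,
)
x = tf.random.uniform([2, 9, 10])  # 2 global + 7 local timesteps per sequence
mask = tf.sequence_mask([9, 7])

# queries_length (9) > max_length_full_attention (3), so sparse attention is used.
outputs, _ = attention(x, mask=mask)
print(outputs.shape)  # (2, 9, 20): global and local outputs concatenated on the time axis
```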
82 changes: 69 additions & 13 deletions opennmt/tests/transformer_test.py
@@ -116,8 +116,23 @@ def testRelativePositions(self):
[[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]],
)

@parameterized.expand([[2, True], [2, False], [3, True], [3, False]])
def testSplitChunks(self, chunk_length, concat_3_chunks):
@parameterized.expand(
[
[2, True],
[2, False],
[3, True],
[3, False],
[2, True, 1],
[2, False, 1],
[3, True, 1],
[3, False, 1],
[2, True, 2],
[2, False, 2],
[3, True, 2],
[3, False, 2],
]
)
def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0):
batch = 3
length = [5, 3, 7]
num_heads = 4
@@ -127,33 +142,62 @@ def testSplitChunks(self, chunk_length, concat_3_chunks):
[batch, num_heads, max(length), depth], dtype=tf.float32
)
split, num_chunks = transformer.split_chunks(
inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks
inputs,
chunk_length=chunk_length,
concat_3_chunks=concat_3_chunks,
global_length=global_length,
)
split_shape = split.shape
self.assertEqual(num_chunks, split_shape[0] / batch)
self.assertEqual(num_heads, split_shape[1])
chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length
chunk_length_eval += global_length
self.assertEqual(chunk_length_eval, split_shape[2])
self.assertEqual(depth, split_shape[3])

@parameterized.expand(
[[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]]
[
[tf.bool, 2],
[tf.float32, 2],
[tf.bool, 3],
[tf.float32, 3],
[tf.bool, 2, 1],
[tf.float32, 2, 1],
[tf.bool, 3, 1],
[tf.float32, 3, 1],
[tf.bool, 2, 2],
[tf.float32, 2, 2],
[tf.bool, 3, 2],
[tf.float32, 3, 2],
]
)
def testChunkAttentionMask(self, dtype, chunk_length):
def testChunkAttentionMask(self, dtype, chunk_length, global_length=0):
length = [2, 4, 3]
batch = len(length)
maximum_length = 5
mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype)
mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length)
output_shape = mask_chunked.shape
num_chunks = abs(-maximum_length // chunk_length)
self.assertEqual(num_chunks, output_shape[0] / batch)
self.assertEqual(chunk_length, output_shape[1])
self.assertEqual(chunk_length * 3, output_shape[2])
mask_chunked = transformer.chunk_att_mask(
mask, chunk_length=chunk_length, global_length=global_length
)
(
output_batch_times_chunks,
output_chunk_length,
output_expanded_chunk_length,
) = mask_chunked.shape
if global_length:
maximum_length = maximum_length - global_length
length = [el - global_length for el in length]
num_chunks = abs(-(maximum_length) // chunk_length)
self.assertEqual(num_chunks * batch, output_batch_times_chunks)
self.assertEqual(chunk_length, output_chunk_length)
self.assertEqual(chunk_length * 3 + global_length, output_expanded_chunk_length)

self.assertIs(mask_chunked.dtype, dtype)

expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype)
expected = np.zeros(
(output_batch_times_chunks, output_chunk_length, chunk_length * 3),
dtype=dtype.as_numpy_dtype,
)

token_radius = chunk_length * 2 + 1
for b in range(batch):
@@ -170,6 +214,14 @@ def testChunkAttentionMask(self, dtype, chunk_length):
expected[chunk_idx][ch_l][start_idx:end_idx] = 1

mask_chunked = self.evaluate(mask_chunked)
if global_length:
expanded_mask = np.repeat(mask, num_chunks, axis=0)
expanded_mask = np.repeat(
expanded_mask[:, np.newaxis, :], chunk_length, axis=1
)
expected = tf.concat(
(expected, expanded_mask[:, :, :global_length]), axis=2
)
self.assertAllEqual(mask_chunked, expected)

def testFeedForwardNetwork(self):
@@ -217,7 +269,11 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self):

def testMultiHeadSelfAttentionSparse(self):
attention = transformer.MultiHeadAttention(
4, 20, local_attention_radius=2, max_length_full_attention=3
4,
20,
local_attention_radius=2,
max_length_full_attention=3,
global_attention_length=2,
)
x = tf.random.uniform([2, 9, 10])
mask = tf.sequence_mask([9, 7])