WIP : First implementation of sliding window local sparse attention. #951

Open · wants to merge 13 commits into base: master
opennmt/layers/transformer.py — 245 changes: 244 additions & 1 deletion
@@ -117,6 +117,173 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0):
"""Splits a tensor into chunks along the timesteps axis.

Args:
a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`.
chunk_length: The length of a chunk :math:`C`.
concat_3_chunks: Optional, if ``True``, concatenate the previous and the next chunk to each chunk.
global_length: Optional, the number of leading timesteps treated as global tokens, which are appended to every chunk.

Returns:
A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3) + G, D]`, where :math:`N` is the number of
chunks and :math:`G` is ``global_length``, and the number of chunks :math:`N`.
"""

if global_length:
global_a = a[:, :, :global_length, :]
a = a[:, :, global_length:, :]

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a factor of chunk_length.
pad_len = -timesteps % chunk_length
# batch, num_heads, timesteps padded, units_per_head
a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]])
padded_len = misc.shape_list(a_padded)[2]

# Chunk along timesteps axis.
num_chunks = padded_len // chunk_length
chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head]
# batch, num_heads, num_chunks, chunk_length, units_per_head
a_chunked = tf.reshape(a_padded, chunked_shape)

# Concatenate previous and next chunk to each chunk, for overlapping.
if concat_3_chunks:
# batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head
a_chunked_padded = tf.pad(
a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]
)
# batch, num_heads, num_chunks, chunk_length*3, units_per_head
a_chunked = tf.concat(
[a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3
)

# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])

if global_length:
# batch, num_chunks, num_heads, global timesteps, units_per_head
expanded_global_a = tf.tile(
tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1]
)
a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3)

input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
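
For reference, a minimal usage sketch of split_chunks (assuming this patched module is importable as opennmt.layers.transformer; sizes are illustrative):

import tensorflow as tf
from opennmt.layers import transformer

# Toy input: batch=2, heads=4, timesteps=10, depth=8.
a = tf.random.normal([2, 4, 10, 8])
# With chunk_length=4, the 10 timesteps are padded to 12, i.e. 3 chunks.
chunks, num_chunks = transformer.split_chunks(a, chunk_length=4)
assert num_chunks == 3
# Each chunk is concatenated with its previous and next chunk: [2 * 3, 4, 4 * 3, 8].
assert chunks.shape == (6, 4, 12, 8)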


def combine_chunks(a, num_chunks, unchunked_length):
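"""Restores a chunked tensor to its original, unchunked shape.

Args:
a: A ``tf.Tensor`` of shape :math:`[B * N, H, C, D]`.
num_chunks: The number of chunks :math:`N`.
unchunked_length: The length :math:`T` of the timesteps axis before chunking.

Returns:
A ``tf.Tensor`` of shape :math:`[B, H, T, D]`.
"""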
# Unchunk
a_shape = misc.shape_list(a)
# batch, num_chunks, num_heads, chunk_length, units_per_head
a = tf.reshape(
a,
[
a_shape[0] // num_chunks,
num_chunks,
a_shape[1],
a_shape[2],
a_shape[3],
],
)
# batch, num_heads, num_chunks, chunk_length, units_per_head
a = tf.transpose(a, perm=[0, 2, 1, 3, 4])
a_shape = misc.shape_list(a)
a = tf.reshape(
a,
[
a_shape[0],
a_shape[1],
a_shape[2] * a_shape[3],
a_shape[4],
],
)

# Remove padding used for chunking.
return a[:, :, :unchunked_length, :]
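
A quick round-trip sanity check, under the same import assumption as above:

import tensorflow as tf
from opennmt.layers import transformer

x = tf.random.normal([2, 4, 10, 8])
# Chunk without the 3-chunk overlap so that combine_chunks can invert the operation.
chunked, num_chunks = transformer.split_chunks(x, chunk_length=4, concat_3_chunks=False)
restored = transformer.combine_chunks(chunked, num_chunks, unchunked_length=10)
# The padding added for chunking is removed, recovering the original tensor.
tf.debugging.assert_equal(restored, x)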


def chunk_att_mask(mask, chunk_length, global_length=0):
"""Transforms an attention mask into a chunked representation.

The chunked mask masks everything except a sliding window of radius ``chunk_length`` around the diagonal.

Args:
mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`.
chunk_length: The length of a chunk :math:`C`.
global_length: Optional, the number of leading timesteps treated as global tokens; their mask is appended to each chunk.

Returns:
A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3 + G]`, where :math:`N` is the number of chunks
and :math:`G` is ``global_length``.
"""

if global_length:
global_mask = mask[:, :global_length]
mask = mask[:, global_length:]

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
rank = len(mask_shape)

if rank == 2:
# Broadcast on queries time dimension.
mask = tf.expand_dims(mask, 1)
mask = tf.broadcast_to(mask, [batch, timesteps, timesteps])

# Pad to a factor of chunk_length.
pad_len = -timesteps % chunk_length
mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]])
padded_timesteps = misc.shape_list(mask)[-1]

# Append chunk_length padding to timestep axis, before and after.
mask_padded = tf.pad(
tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]]
)
padded_len = misc.shape_list(mask_padded)[-1]
mask_flattened = tf.reshape(mask_padded, shape=[batch, -1])

# Skew to the left by one and keep 2*chunk_length + 1 relevant locations.
# This corresponds to chunk_length radius around the diagonal.
skewed_len = padded_len + 1
skewed_padding_len = (
padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1]
)
mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]])
skewed_shape = [batch, -1, skewed_len]
mask_skewed = tf.reshape(mask_padded, shape=skewed_shape)
mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1]

chunk_num = padded_timesteps // chunk_length
mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1])

# Unskew each chunk to be compatible with chunked attention shape.
unskewed_len = chunk_length * 3
mask_skewed_padded = tf.pad(
mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]]
)
mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1])
mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)]
mask_unskewed = tf.reshape(
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

if global_length:
# batch, num_chunks, chunk_length, global_length
expanded_global_mask = tf.tile(
global_mask[:, tf.newaxis, tf.newaxis, :], [1, chunk_num, chunk_length, 1]
)
mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed,
shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length],
)
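
A small sketch of the resulting mask shape (same import assumption; lengths are illustrative):

import tensorflow as tf
from opennmt.layers import transformer

# Two sequences of lengths 7 and 5, padded to 8 timesteps.
mask = tf.sequence_mask([7, 5], maxlen=8, dtype=tf.float32)
chunked_mask = transformer.chunk_att_mask(mask, chunk_length=4)
# 8 timesteps -> 2 chunks of length 4, each attending to 3 * 4 key positions.
assert chunked_mask.shape == (2 * 2, 4, 4 * 3)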


class FeedForwardNetwork(tf.keras.layers.Layer):
"""Implements the Transformer's "Feed Forward" layer.

@@ -214,6 +381,9 @@ def __init__(
dropout=0.1,
return_attention=False,
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=None,
global_attention_length=0,
**kwargs
):
"""Initializes this layer.
@@ -225,6 +395,11 @@ def __init__(
return_attention: If ``True``, also return the attention weights.
maximum_relative_position: Maximum relative position representation
(from https://arxiv.org/abs/1803.02155).
max_length_full_attention: Maximum sequence length for full attention.
If not ``None``, use sparse attention for longer sequences
(from https://arxiv.org/abs/2004.08483).
local_attention_radius: Attention radius around each token for local sliding attention.
global_attention_length: Number of leading tokens that attend to, and are attended by, all positions when sparse attention is used.
kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
Expand All @@ -242,6 +417,9 @@ def __init__(
self.dropout = dropout
self.return_attention = return_attention
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius
self.global_attention_length = global_attention_length

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -354,24 +532,82 @@ def _compute_kv(x):

cache = (keys, values)

queries_length = misc.shape_list(queries)[2]

use_sparse_att = False
if self.max_length_full_attention is not None:
if memory is not None:
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
if self.return_attention:
raise ValueError(
"Cannot return attention weights when using sparse attention."
)

use_sparse_att = queries_length > self.max_length_full_attention

chunk_length = self.local_attention_radius
# Dot product attention.
if use_sparse_att:
if self.global_attention_length:
global_queries = queries[:, :, : self.global_attention_length, :]
queries = queries[:, :, self.global_attention_length :, :]
global_keys = keys
global_values = values
global_dot = tf.matmul(global_queries, global_keys, transpose_b=True)

# batch x num_chunks, num_heads, chunk_length, units_per_head
queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
keys, _ = split_chunks(
keys, chunk_length, global_length=self.global_attention_length
)
# batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
values, num_chunks = split_chunks(
values, chunk_length, global_length=self.global_attention_length
)
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
queries, relative_repr_keys, transpose_b=True
)
if mask is not None:
mask = tf.cast(mask, tf.float32)
if use_sparse_att:
global_mask = mask[:, tf.newaxis, tf.newaxis, :]
mask = chunk_att_mask(mask, chunk_length, self.global_attention_length)
elif mask.shape.rank == 2:
mask = tf.expand_dims(mask, 1) # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = tf.cast(
tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min),
dot.dtype,
)
if use_sparse_att and self.global_attention_length:
global_dot = tf.cast(
tf.cast(global_dot, tf.float32) * global_mask
+ ((1.0 - global_mask) * tf.float32.min),
global_dot.dtype,
)

attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
drop_attn = common.dropout(attn, self.dropout, training=training)
heads = tf.matmul(drop_attn, values)

if use_sparse_att and self.global_attention_length:
global_attn = tf.cast(
tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype
)
global_drop_attn = common.dropout(
global_attn, self.dropout, training=training
)
global_heads = tf.matmul(global_drop_attn, global_values)

if use_sparse_att:
heads = combine_chunks(
heads, num_chunks, queries_length - self.global_attention_length
)
if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
@@ -380,6 +616,13 @@ def _compute_kv(x):
# Concatenate all heads output.
combined = combine_heads(heads)
outputs = self.linear_output(combined)
if use_sparse_att and self.global_attention_length:
global_combined = combine_heads(global_heads)
global_outputs = self.linear_output(
global_combined
)  # TODO: use separate linear input/output layers for the global attention?
outputs = tf.concat((global_outputs, outputs), axis=1)

if self.return_attention:
return outputs, cache, attn
return outputs, cache
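
For context, a hedged sketch of how the new options could be passed to the layer defined in this file (assuming the class is MultiHeadAttention; the values below are illustrative, not recommendations, and this is a WIP branch):

import tensorflow as tf
from opennmt.layers import transformer

attention = transformer.MultiHeadAttention(
    num_heads=8,
    num_units=512,
    max_length_full_attention=512,  # sequences longer than this take the sparse path
    local_attention_radius=128,     # used as the chunk length for the sliding window
    global_attention_length=16,     # leading tokens with global attention
)

# 1024 > 512, so this call uses sliding window sparse attention.
inputs = tf.random.normal([2, 1024, 512])
outputs, cache = attention(inputs, training=True)
assert outputs.shape == (2, 1024, 512)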