-
Notifications
You must be signed in to change notification settings - Fork 512
/
attention_tf.py
87 lines (81 loc) · 3.51 KB
/
attention_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#! -*- coding: utf-8 -*-
import tensorflow as tf
def Position_Embedding(inputs, position_size):
    """Build sinusoidal position embeddings sized to the batch/time of inputs.

    Args:
        inputs: tensor shaped (batch_size, seq_len, word_size); only its
            dynamic batch size and sequence length are read.
        position_size: width of the embedding. Must be even: for an odd value
            the cos/sin halves together have position_size + 1 columns and the
            final broadcast against the zeros tensor fails.

    Returns:
        float32 tensor shaped (batch_size, seq_len, position_size). For
        position i, the first half of the last axis holds cos(i * w_k) and the
        second half holds sin(i * w_k), where w_k = 1 / 10000^(2k / position_size).
    """
    dims = tf.shape(inputs)
    batch_size = dims[0]
    seq_len = dims[1]

    # One frequency per cos/sin column pair.
    freq_index = tf.range(position_size / 2, dtype=tf.float32)
    inv_freq = 1. / tf.pow(10000., 2 * freq_index / position_size)

    # Outer product positions x frequencies via a (seq_len, 1) @ (1, half) matmul.
    positions = tf.range(tf.cast(seq_len, tf.float32), dtype=tf.float32)
    angles = tf.matmul(tf.expand_dims(positions, 1), tf.expand_dims(inv_freq, 0))
    table = tf.concat([tf.cos(angles), tf.sin(angles)], 1)

    # Adding a zeros tensor broadcasts the shared table across the batch axis.
    return tf.expand_dims(table, 0) + tf.zeros((batch_size, seq_len, position_size))
def Mask(inputs, seq_len, mode='mul'):
    """Mask the padded tail of every sequence along axis 1 of ``inputs``.

    Args:
        inputs: tensor of rank >= 2 whose rank is statically known, e.g.
            shaped (batch_size, seq_len, input_size). Axis 0 is the batch axis
            and axis 1 is the time axis that gets masked.
        seq_len: None, or a (batch_size,) integer tensor/array holding the real
            length of each sequence; positions at or beyond it are masked.
        mode: 'mul' multiplies the masked positions by zero (typically used
            before a dense layer); 'add' subtracts a large constant (1e12) from
            them (typically used before a softmax).

    Returns:
        ``inputs`` itself when ``seq_len`` is None, otherwise the masked tensor.

    Raises:
        ValueError: if ``seq_len`` is given and ``mode`` is neither 'mul' nor 'add'.
    """
    # `is None` rather than `== None`: with a NumPy array of lengths, `==` is
    # elementwise and the `if` would raise on an ambiguous truth value.
    if seq_len is None:
        return inputs
    if mode not in ('mul', 'add'):
        raise ValueError("mode must be 'mul' or 'add', got %r" % (mode,))
    # Size the mask by the actual length of axis 1 rather than max(seq_len):
    # without maxlen, a padded batch whose longest sequence is shorter than the
    # time axis produces a mask that does not line up with inputs.
    mask = tf.sequence_mask(seq_len, maxlen=tf.shape(inputs)[1])
    mask = tf.cast(mask, tf.float32)
    # Append singleton axes so the (batch, time) mask broadcasts over every
    # trailing dimension of inputs.
    for _ in range(len(inputs.shape) - 2):
        mask = tf.expand_dims(mask, 2)
    if mode == 'mul':
        return inputs * mask
    return inputs - (1 - mask) * 1e12
def Dense(inputs, ouput_size, bias=True, seq_len=None):
    """Plain fully connected layer applied to the last axis of ``inputs``.

    Args:
        inputs: tensor of rank >= 2 shaped (batch_size, ..., input_size); the
            last dimension must be statically known (it is read with int()).
        ouput_size: number of output units. The misspelled parameter name is
            kept so existing keyword callers keep working.
        bias: whether to create and add a bias vector.
        seq_len: optional (batch_size,) lengths; when given, positions past
            each length are zeroed with Mask(..., 'mul').

    Returns:
        tensor shaped (batch_size, ..., ouput_size). Only the last dimension
        goes through the matrix multiplication.
    """
    input_size = int(inputs.shape[-1])
    # Fresh variables are created on every call, initialised uniformly in
    # [-0.05, 0.05].
    # NOTE(review): tf.random_uniform is the TF 1.x name; TF 2 only exposes
    # it as tf.compat.v1.random_uniform — confirm the target TF version.
    W = tf.Variable(tf.random_uniform([input_size, ouput_size], -0.05, 0.05))
    if bias:
        b = tf.Variable(tf.random_uniform([ouput_size], -0.05, 0.05))
    else:
        b = 0
    # Flatten all leading axes so one 2-D matmul suffices, then restore the
    # original leading shape with the new last dimension.
    outputs = tf.matmul(tf.reshape(inputs, (-1, input_size)), W) + b
    outputs = tf.reshape(
        outputs, tf.concat([tf.shape(inputs)[:-1], [ouput_size]], 0))
    # `is not None` rather than `!= None`: with a NumPy array of lengths, `!=`
    # is elementwise and the `if` would raise on an ambiguous truth value.
    if seq_len is not None:
        outputs = Mask(outputs, seq_len, 'mul')
    return outputs
def Attention(Q, K, V, nb_head, size_per_head, Q_len=None, V_len=None):
    """Multi-head scaled dot-product attention.

    Args:
        Q, K, V: query/key/value tensors, presumably shaped (batch_size, len,
            dim). After projection, axis 1 of each is read as its sequence
            length. K and V must share that length for the final matmul.
        nb_head: number of attention heads.
        size_per_head: width of each head. Every input is linearly projected
            (without bias) to nb_head * size_per_head features.
        Q_len: optional (batch_size,) lengths; output rows past them are zeroed.
        V_len: optional (batch_size,) lengths; key positions past them are
            pushed down with Mask(..., mode='add') before the softmax.

    Returns:
        tensor shaped (batch_size, q_len, nb_head * size_per_head).
    """
    hidden = nb_head * size_per_head

    def project_and_split(x):
        # Linear map, then (batch, len, hidden) -> (batch, nb_head, len, size_per_head).
        x = Dense(x, hidden, False)
        x = tf.reshape(x, (-1, tf.shape(x)[1], nb_head, size_per_head))
        return tf.transpose(x, [0, 2, 1, 3])

    # Projection order Q, K, V is kept so variables are created in the same order.
    queries = project_and_split(Q)
    keys = project_and_split(K)
    values = project_and_split(V)

    # Scaled scores shaped (batch, nb_head, q_len, k_len).
    scores = tf.matmul(queries, keys, transpose_b=True) / tf.sqrt(float(size_per_head))
    # Mask works on axis 1, so swap the key axis there, mask, then swap back
    # (the permutation is its own inverse).
    swap = [0, 3, 2, 1]
    scores = tf.transpose(Mask(tf.transpose(scores, swap), V_len, mode='add'), swap)
    weights = tf.nn.softmax(scores)

    # Weighted sum of values, then merge the heads back into one feature axis.
    merged = tf.transpose(tf.matmul(weights, values), [0, 2, 1, 3])
    merged = tf.reshape(merged, (-1, tf.shape(merged)[1], hidden))
    return Mask(merged, Q_len, 'mul')