modules.py (forked from ZiyaoGeng/RecLearn)
"""
Created on Oct 26, 2020
modules of DIN: attention mechanism
@author: Ziyao Geng
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer, BatchNormalization, Dense, PReLU


class Attention_Layer(Layer):
    def __init__(self, att_hidden_units, activation='prelu'):
        """Local activation unit of DIN.

        :param att_hidden_units: A list. Hidden units of each attention dense layer.
        :param activation: A string. 'prelu' maps to a PReLU layer; any other
            value falls back to the Dice activation defined below.
        """
        super(Attention_Layer, self).__init__()
        # Keras has no 'prelu' activation string, so the name is mapped to a
        # PReLU layer here; any other value falls back to Dice, matching the
        # two activations used in the DIN paper.
        self.att_dense = [Dense(unit, activation=PReLU() if activation == 'prelu' else Dice())
                          for unit in att_hidden_units]
        self.att_final_dense = Dense(1)
    def call(self, inputs):
        # query: candidate item (None, d * 2), d is the dimension of embedding
        # key:   hist items     (None, seq_len, d * 2)
        # value: hist items     (None, seq_len, d * 2)
        # mask:  (None, seq_len)
        q, k, v, mask = inputs
        # Broadcast the query along the sequence axis so it can be compared
        # with every historical item.
        q = tf.tile(q, multiples=[1, k.shape[1]])  # (None, seq_len * d * 2)
        q = tf.reshape(q, shape=[-1, k.shape[1], k.shape[2]])  # (None, seq_len, d * 2)

        # Concatenate the query, the key, their difference, and their
        # element-wise product as input features for the attention MLP.
        info = tf.concat([q, k, q - k, q * k], axis=-1)

        # Attention MLP.
        for dense in self.att_dense:
            info = dense(info)

        outputs = self.att_final_dense(info)  # (None, seq_len, 1)
        outputs = tf.squeeze(outputs, axis=-1)  # (None, seq_len)

        # Fill padded positions with a large negative value so that softmax
        # assigns them (almost) zero weight.
        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)  # (None, seq_len)
        outputs = tf.where(tf.equal(mask, 0), paddings, outputs)  # (None, seq_len)

        # softmax
        outputs = tf.nn.softmax(logits=outputs)  # (None, seq_len)

        # Weighted sum over the historical items.
        outputs = tf.expand_dims(outputs, axis=1)  # (None, 1, seq_len)
        outputs = tf.matmul(outputs, v)  # (None, 1, d * 2)
        outputs = tf.squeeze(outputs, axis=1)  # (None, d * 2)
        return outputs
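
# A minimal usage sketch of Attention_Layer (illustrative, with assumed
# shapes; not part of the original module): the query is a candidate-item
# embedding, the keys/values are the behaviour-sequence embeddings, and
# positions where mask == 0 receive (almost) zero attention weight.
#
#   att = Attention_Layer(att_hidden_units=[8, 4])
#   q = tf.random.normal([2, 16])           # (batch, d * 2)
#   k = tf.random.normal([2, 5, 16])        # (batch, seq_len, d * 2)
#   mask = tf.constant([[1, 1, 1, 0, 0],
#                       [1, 1, 0, 0, 0]])   # (batch, seq_len)
#   pooled = att([q, k, k, mask])           # (batch, d * 2)
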
class Dice(Layer):
    """Data Adaptive Activation Function (Dice) from the DIN paper."""

    def __init__(self):
        super(Dice, self).__init__()
        self.bn = BatchNormalization(center=False, scale=False)
        # alpha controls the slope of the negative branch; it is initialised
        # to zero here, as in common Dice implementations.
        self.alpha = self.add_weight(shape=(), dtype=tf.float32, name='alpha',
                                     initializer=tf.zeros_initializer())

    def call(self, x):
        # p(x) = sigmoid(BN(x)) acts as a data-dependent gate between the
        # identity branch and the alpha-scaled branch.
        x_normed = self.bn(x)
        x_p = tf.sigmoid(x_normed)
        return self.alpha * (1.0 - x_p) * x + x_p * x
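

# A quick self-check when the module is run directly: an added sketch, not
# part of the original file (layer sizes and tensor shapes are assumptions).
# It builds both layers on small random tensors and prints output shapes.
if __name__ == '__main__':
    dice = Dice()
    x = tf.random.normal([4, 8])
    print('Dice output shape:', dice(x).shape)  # (4, 8)

    att = Attention_Layer(att_hidden_units=[8, 4], activation='dice')
    q = tf.random.normal([2, 16])
    k = tf.random.normal([2, 5, 16])
    mask = tf.constant([[1, 1, 1, 0, 0],
                        [1, 1, 0, 0, 0]])
    print('Attention output shape:', att([q, k, k, mask]).shape)  # (2, 16)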