pooling.py
import tensorflow as tf


class MaskGlobalMaxPooling1D(tf.keras.layers.Layer):
    """Global max pooling over the time axis that ignores masked (padded) positions."""

    def __init__(self, **kwargs):
        super(MaskGlobalMaxPooling1D, self).__init__(**kwargs)

    def call(self, inputs, mask=None):
        x = inputs
        if mask is not None:
            # Expand the (batch, time) mask so it broadcasts over the feature axis,
            # then push padded positions to a large negative value so they never win the max
            mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
            x = x - (1.0 - mask) * 1e12
        return tf.reduce_max(x, axis=1)
class MaskGlobalAveragePooling1D(tf.keras.layers.Layer):
    """Global average pooling over the time axis that ignores masked (padded) positions."""

    def __init__(self, **kwargs):
        super(MaskGlobalAveragePooling1D, self).__init__(**kwargs)

    def call(self, inputs, mask=None):
        if mask is None:
            # No mask: plain average over all time steps
            return tf.reduce_mean(inputs, axis=1)
        # Expand the (batch, time) mask so it broadcasts over the feature axis
        mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
        # Sum only the valid positions, then divide by the number of valid time steps
        x = tf.reduce_sum(inputs * mask, axis=1)
        return x / tf.reduce_sum(mask, axis=1)