diff --git a/modules.py b/modules.py index 34bf4c9..43ebbe9 100644 --- a/modules.py +++ b/modules.py @@ -103,7 +103,7 @@ def scaled_dot_product_attention(Q, K, V, key_masks, def mask(inputs, key_masks=None, type=None): """Masks paddings on keys or queries to inputs inputs: 3d tensor. (h*N, T_q, T_k) - key_masks: 3d tensor. (N, 1, T_k) + key_masks: 2d tensor. (N, T_k) type: string. "key" | "future" e.g., @@ -303,4 +303,4 @@ def noam_scheme(init_lr, global_step, warmup_steps=4000.): until it reaches init_lr. ''' step = tf.cast(global_step + 1, dtype=tf.float32) - return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5) \ No newline at end of file + return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)