-
Notifications
You must be signed in to change notification settings - Fork 13
/
my_loss.py
94 lines (88 loc) · 4.58 KB
/
my_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import math_ops
import tensorflow as tf
# Public names exported by a wildcard import of this module.
__all__ = ["sequence_loss"]
def _repeat_per_timestep(tensor, batch_size, max_time):
    """Repeats every row of a 2D tensor once per timestep.

    Args:
      tensor: A 2D Tensor of shape [batch_size x depth].
      batch_size: Scalar int Tensor holding the batch size.
      max_time: Scalar int Tensor holding the sequence length.
    Returns:
      A 2D Tensor of shape [batch_size * max_time x depth] in which rows
      `b * max_time` to `(b + 1) * max_time - 1` are all copies of row `b` of
      `tensor`, i.e. the same row order as logits flattened over
      (batch, time).
    """
    depth = array_ops.shape(tensor)[1]
    # Tiling along the depth axis and then reshaping row-major places the
    # max_time copies of each batch row next to each other.
    tiled = tf.tile(tensor, [1, max_time])
    return array_ops.reshape(tiled, [batch_size * max_time, depth])


def sequence_loss(logits, targets, weights, extra_information, label_embedding,
                  average_across_timesteps=True, average_across_batch=True,
                  softmax_loss_function=None, name=None):
    """Weighted cross-entropy loss for a sequence of logits (per example).

    Args:
      logits: A 3D Tensor of shape
        [batch_size x sequence_length x num_decoder_symbols] and dtype float.
        The logits correspond to the prediction across all classes at each
        timestep.
      targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
        int. The target represents the true class at each timestep.
      weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
        float. Weights constitutes the weighting of each prediction in the
        sequence. When using weights as masking set all valid timesteps to 1
        and all padded timesteps to 0.
      extra_information: A 2D Tensor of shape [batch_size x latent_size]. It is
        repeated for every timestep and handed to `softmax_loss_function`. It
        is only read when `softmax_loss_function` is given.
      label_embedding: A 2D Tensor of shape [batch_size x embed_size]. It is
        repeated for every timestep and handed to `softmax_loss_function`. It
        is only read when `softmax_loss_function` is given.
      average_across_timesteps: If set, sum the cost across the sequence
        dimension and divide the cost by the total label weight across
        timesteps.
      average_across_batch: If set, sum the cost across the batch dimension
        and divide the cost by the total label weight across the batch.
      softmax_loss_function: Function
        (flat_logits, flat_targets, flat_extra_information,
        flat_label_embedding, max_time) -> loss-batch, to be used instead of
        the standard sparse softmax cross-entropy (the default if this is
        None). The first four arguments have leading dimension
        batch_size * sequence_length; the result must be reshapeable to
        [batch_size x sequence_length].
      name: Optional name for this operation, defaults to "sequence_loss".
    Returns:
      A float Tensor holding the weighted loss:
        * a scalar when both averaging flags are set,
        * shape [batch_size] when only `average_across_timesteps` is set,
        * shape [sequence_length] when only `average_across_batch` is set,
        * shape [batch_size x sequence_length] when neither is set.
    Raises:
      ValueError: logits does not have 3 dimensions or targets does not have 2
        dimensions or weights does not have 2 dimensions.
    """
    if len(logits.get_shape()) != 3:
        raise ValueError("Logits must be a "
                         "[batch_size x sequence_length x logits] tensor")
    if len(targets.get_shape()) != 2:
        raise ValueError("Targets must be a [batch_size x sequence_length] "
                         "tensor")
    if len(weights.get_shape()) != 2:
        raise ValueError("Weights must be a [batch_size x sequence_length] "
                         "tensor")
    with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
        logits_shape = array_ops.shape(logits)
        batch_size = logits_shape[0]
        max_time = logits_shape[1]
        num_classes = logits_shape[2]

        # Flatten (batch, time) into one leading axis for the per-step loss.
        probs_flat = array_ops.reshape(logits, [-1, num_classes])
        targets = array_ops.reshape(targets, [-1])

        if softmax_loss_function is None:
            crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
                labels=targets, logits=probs_flat)
        else:
            # The per-example side inputs are expanded only on this path, so
            # the default loss neither requires them nor adds unused ops.
            expand_extra_information = _repeat_per_timestep(
                extra_information, batch_size, max_time)
            expand_label_embedding = _repeat_per_timestep(
                label_embedding, batch_size, max_time)
            crossent = softmax_loss_function(
                probs_flat, targets, expand_extra_information,
                expand_label_embedding, max_time)

        # Back to [batch, time] so the weights apply elementwise.
        crossent = array_ops.reshape(crossent, [-1, max_time]) * weights

        # Pick the axes to average over; the three modes are mutually
        # exclusive, and with no averaging the weighted losses are returned.
        if average_across_timesteps and average_across_batch:
            axis = None  # reduce over every axis -> scalar
        elif average_across_timesteps:
            axis = [1]
        elif average_across_batch:
            axis = [0]
        else:
            return crossent

        total_size = math_ops.reduce_sum(weights, axis=axis)
        total_size += 1e-12  # to avoid division by 0 for all-0 weights
        return math_ops.reduce_sum(crossent, axis=axis) / total_size