We elaborate on the technical details of the objective functions required for training constrained disagreement classifiers.
\subsection{Disagreement Cross Entropy}\label{subsec:disagreement-cross-entropy}
\xhdr{Intuition}
In our methodology, we propose the \textit{disagreement cross entropy} (DCE) as a simple loss function that encourages a classifier to disagree with a target label while otherwise outputting a high entropy prediction.
In the binary classification case, the DCE is equivalent to simply flipping the target label and computing the regular cross-entropy.
Consider a model that outputs a distribution $\{p_1, p_2, 1-p_1-p_2\}$ over $3$ classes, and suppose we would like to train it to \textbf{not} output high probability for class $3$, i.e., to minimize $p_3=1-p_1-p_2$.
An intuitive approach would be to \textit{maximize} the standard cross-entropy loss $-\log(1-p_1-p_2)$ that one would equivalently \textit{minimize} when trying to output high probability for class $3$.
We show in \autoref{fig:dce} that this objective is unstable, whereas the DCE $-\frac{1}{2}(\log(p_1) + \log(p_2))$ is convex and has a bounded minimum at $p_1=p_2=1/2$.
In practice, if one wishes for a loss function with a minimum of zero, it suffices to subtract $\log(N-1)$ from \autoref{eq:anitcross}.
A visual description of the DCE is provided in \autoref{fig:dce}.
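For reference, the general $N$-class form of the DCE used throughout this appendix (cf.\ \autoref{eq:anitcross}), for a predicted distribution $\mathbf{p}$ and target label $y$, is
\begin{equation*}
\text{DCE}(\mathbf{p}, y) = \frac{1}{1-N}\sum_{i=1}^{N} \log(\mathbf{p}_i)\, \mathbb{I}[i\neq y],
\end{equation*}
which reduces to $-\frac{1}{2}(\log(p_1) + \log(p_2))$ in the three-class example above with $y=3$.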
\begin{figure}[!htb]
\centering
\includegraphics[width=\linewidth]{images/dce.pdf}
\caption{Consider a classifier that outputs a distribution $\{p_1, p_2, 1-p_1-p_2\}$ over $3$ classes. We compare the optimization landscape induced by (left) minimizing the negative cross entropy for the target $p_3=1-p_1-p_2$ (i.e., trying not to predict class $3$), (center) minimizing our \textit{disagreement cross entropy} (DCE) for the same target,
and (right) minimizing the regular cross entropy (i.e., trying to predict class $3$). We observe that naively minimizing the negative cross entropy results in an unbounded minimum, while the DCE is significantly more stable and scales similarly to the regular cross entropy. Gradients are overlaid to help better visualize the 3D geometry.}
\label{fig:dce}
\end{figure}
\subsection{Validity of the DCE Measure}
We introduce a definition of a \textit{valid disagreement loss function} and show that the DCE loss satisfies it.
\begin{definition}[A Valid Disagreement Loss Function]
Let $P$ be the set of all probability vectors of length $N$ whose maximum is attained uniquely at a fixed target label $y\in\{1,\ldots, N\}$.
Similarly, let $Q$ be the set of all probability vectors whose maximum is \textbf{not} attained uniquely at $y$.
For instance, in the three-dimensional case with $y=3$
\begin{align*}
&P=\{(p_1,p_2,1-p_1-p_2) \text{ s.t. } \max(p_1,p_2) < 1-p_1-p_2\}\\
&Q=\{(p_1,p_2,1-p_1-p_2) \text{ s.t. } \max(p_1,p_2) \geq 1-p_1-p_2\}\\
&\text{where } 0\leq p_1,p_2\leq 1 \text{ and } p_1+p_2\leq 1.
\end{align*}
A score $S(\mathbf{p}; y)$ is a valid disagreement loss if
\begin{equation}
\forall\ \mathbf{p}\in P\ \min_{\mathbf{q}\in Q} S(\mathbf{q}; y) \leq S(\mathbf{p}; y)
\end{equation}
which is to say that for every probability vector in $P$ there exists a probability vector in $Q$ that achieves a score at least as low.
\end{definition}
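To make the sets concrete, the following is a small illustrative sketch (the helper \verb!in_P! is ours and not part of the method) that checks whether a probability vector lies in $P$ or in $Q$ for a target index $y$.
\begin{python}
import torch

def in_P(p, y):
    # p lies in P iff its maximum is attained uniquely at index y
    others = torch.cat([p[:y], p[y + 1:]])
    return bool(p[y] > others.max())

p = torch.tensor([0.25, 0.25, 0.50])
q = torch.tensor([0.40, 0.40, 0.20])
print(in_P(p, 2), in_P(q, 2))  # True, False (so q lies in Q)
\end{python}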
\begin{manualtheorem}{3}[Validity of DCE]
The Disagreement Cross Entropy Loss is a valid disagreement loss function.
\end{manualtheorem}
\begin{proof}
We show that the minimum DCE attainable by a probability vector $\mathbf{q}\in Q$ is $\log(N-1)$, while the DCE of any $\mathbf{p}\in P$ is bounded below by $\log(N)$.
First note that, by definition of the DCE in \autoref{eq:anitcross}, the unique probability vector $\mathbf{v}$ that globally minimizes it with respect to a target class $y$ is
the vector with elements $\mathbf{v}_i=1/(N-1)$ for all indices $i\neq y$ and $\mathbf{v}_y=0$. The vector $\mathbf{v}$ is a member of $Q$ because it does not have a unique maximum at position $y$.
Computing the DCE of $\mathbf{v}$ with respect to the target $y$ gives:
\begin{align*}
\text{DCE}(\mathbf{v}, y) &= \frac{1}{1-N}\sum_{i=1}^N \log(\mathbf{v}_i) \mathbb{I}[i\neq y] \\
&= \frac{1}{1-N} \log\left(\frac{1}{N-1}\right) \times (N-1)\\
&=\log(N-1)
\end{align*}
Next, note that $\mathbf{p}_y$ does not appear in the DCE, which has the form $\frac{1}{1-N}\sum_{i\neq y} \log(\mathbf{p}_i)$ with a negative leading factor; minimizing it therefore means maximizing the non-target entries, so the minimal solution must make $\mathbf{p}_y$ as small as possible. However, to enforce $\mathbf{p}\in P$ we must have that $\forall\ i\in\{1,\ldots,N\}\setminus \{y\},\ \mathbf{p}_y > \mathbf{p}_i$.
Hence, minimizing $\mathbf{p}_y$ while ensuring $\mathbf{p}\in P$ forces $\mathbf{p}_y > 1/N$ and all other $\mathbf{p}_i < 1/N$, since if $\mathbf{p}_y \leq 1/N$ we could also choose some $\mathbf{p}_i \geq 1/N$ and produce a vector $\mathbf{p}\in Q$. Without loss of generality, choosing $\mathbf{p}_y = 1/N +\varepsilon$ for some small $\varepsilon > 0$, we must choose all other $\mathbf{p}_i= 1/N - \varepsilon_i$ such that all $\varepsilon_i > 0$ and $\sum_{i\in \{1,\ldots, N\}\setminus \{y\}} \varepsilon_i = \varepsilon$. This gives a DCE of:
\begin{align*}
&\text{DCE}(\mathbf{p}, y) = \frac{1}{1-N}\sum_{i=1}^N \log(\mathbf{p}_i) \mathbb{I}[i\neq y] = \frac{1}{1-N} \sum_{i\in \{1,\ldots, N\}\setminus \{y\}} \log\left(\frac{1}{N}-\varepsilon_i\right)\\
&\text{ where } \lim_{\varepsilon\to 0 }\frac{1}{1-N} \sum_{i\in \{1,\ldots, N\}\setminus \{y\}} \log\left(\frac{1}{N}-\varepsilon_i\right) = \log(N)
\end{align*}
Thus no vector $\mathbf{p}\in P$ can achieve a DCE below $\log(N)$, while the minimum DCE over $Q$ is $\log(N-1) < \log(N)$, hence proving that the DCE is a valid disagreement loss function.
\end{proof}
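As a numerical sanity check, the following short sketch (purely illustrative; the helper \verb!dce! and the chosen vectors are not part of the method) evaluates the DCE at the minimizer in $Q$ and at a near-uniform member of $P$, recovering the $\log(N-1)$ and $\approx\log(N)$ values derived above.
\begin{python}
import torch

def dce(p, y):
    # DCE(p, y) = 1/(1-N) * sum_{i != y} log(p_i)
    N = p.shape[-1]
    keep = torch.ones(N, dtype=torch.bool)
    keep[y] = False
    return p[keep].log().sum() / (1 - N)

N, y = 5, 0
# minimizer in Q: zero mass on y, uniform over the remaining classes
v = torch.full((N,), 1 / (N - 1)); v[y] = 0.0
# near-uniform member of P: slightly more mass on y than elsewhere
eps = 1e-4
p = torch.full((N,), 1 / N - eps / (N - 1)); p[y] = 1 / N + eps

print(dce(v, y))  # ~= log(N - 1) ~= 1.386
print(dce(p, y))  # ~= log(N)     ~= 1.609
\end{python}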
\subsection{Implementation}
We provide a simple PyTorch-style implementation for the batched computation of the CDC objective in \autoref{eq:loss}, given a batch of logits and targets.
We assume \verb!logits! is a floating point tensor of shape $(\text{batch\_size}\times N)$ containing logit values, \verb!targets! is an integer vector of shape $(\text{batch\_size})$ containing target classes, and \verb!mask! is a 0-1 vector of shape $(\text{batch\_size})$ that is 1 at index $i$
if example $i$ belongs to $\boldP$ and 0 otherwise.
\begin{python}
import torch
import torch.nn.functional as F

def cdc_loss(logits, targets, mask, weight):
    # weight is typically 1 / (size of the test set + 1)
    N = logits.shape[1]
    mask = mask.bool()
    # select examples in P
    logits_p, targets_p = logits[mask], targets[mask]
    # select examples in Q
    logits_q, targets_q = logits[~mask], targets[~mask]
    # standard cross entropy on P with its original targets
    loss_p = F.cross_entropy(logits_p, targets_p, reduction="none")
    # DCE on Q: 1/(1-N) * sum_{i != y} log softmax(logits)_i
    zero_hot = 1 - F.one_hot(targets_q, num_classes=N)
    loss_q = (logits_q * zero_hot).sum(dim=1) / (1 - N) \
             + torch.logsumexp(logits_q, dim=1)
    # weighted mean over the full batch
    return torch.cat([loss_p, weight * loss_q]).mean()
\end{python}
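For completeness, the following is a minimal usage sketch; the batch size, number of classes, mask, and weight value below are illustrative assumptions rather than values prescribed by the method.
\begin{python}
# minimal usage sketch; shapes, mask, and weight are illustrative assumptions
batch_size, N = 8, 10
logits = torch.randn(batch_size, N, requires_grad=True)
targets = torch.randint(0, N, (batch_size,))
mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0])  # first half in P, rest in Q
weight = 1.0 / (1000 + 1)                      # e.g. 1 / (test set size + 1)
loss = cdc_loss(logits, targets, mask, weight)
loss.backward()
\end{python}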