
Commit

Use multiple target objectives for distillation. Also see cl/356382304
PiperOrigin-RevId: 356382406
Mesh TensorFlow Team committed Feb 9, 2021
1 parent 9625f34 commit 6d840eb
Showing 1 changed file with 45 additions and 8 deletions.
53 changes: 45 additions & 8 deletions mesh_tensorflow/transformer/transformer.py
@@ -1650,7 +1650,10 @@ def __init__(self,
teacher,
temperature=None,
fraction_soft=None,
- distill_start_step=0,
+ mse_coeff=0.,
+ kl_coeff=0.,
+ cosine_coeff=0.,
+ distill_start_steps=0,
teacher_checkpoint=None,
initialize_student_weights=False):
"""Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
target cross entropy to the training loss. The rest of the loss will be
the cross entropy with the one-hot actual label. Required only when
training.
- distill_start_step: an int, training steps after which teacher loss is
+ mse_coeff: a float, coefficient on the MSE distillation loss.
+ kl_coeff: a float, coefficient on the KL-divergence distillation loss.
+ cosine_coeff: a float, coefficient on the cosine-embedding distillation loss.
+ distill_start_steps: an int, the training step after which distillation losses are
incorporated in the overall loss.
teacher_checkpoint: a string, the path to the teacher checkpoint that we
wish to use. Required only when training.
@@ -1676,9 +1682,15 @@ def __init__(self,
self.teacher = teacher
self.temperature = temperature
self.fraction_soft = fraction_soft
- self.distill_start_step = distill_start_step
+ self.distill_start_steps = distill_start_steps
self.teacher_checkpoint = teacher_checkpoint
self.initialize_student_weights = initialize_student_weights
+ self.kl_coeff = kl_coeff
+ self.cosine_coeff = cosine_coeff
+ self.mse_coeff = mse_coeff
+ if (fraction_soft + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+   raise ValueError("Distillation coefficients must not add up to a value "
+                    "greater than 1.")

def call_simple(self,
inputs,
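For context, a minimal sketch of how the new multi-objective arguments might be wired up when constructing a StudentTeacher. The student/teacher model objects, the keyword name `student`, and the checkpoint path are assumptions for illustration; only the coefficient arguments and the sum-to-at-most-one constraint come from this change.

student_teacher = StudentTeacher(
    student=student_model,            # assumed: an existing student transformer
    teacher=teacher_model,            # assumed: an existing teacher transformer
    temperature=2.0,
    fraction_soft=0.4,                # weight on temperature-scaled soft cross entropy
    kl_coeff=0.3,                     # weight on the KL-divergence distillation term
    mse_coeff=0.2,                    # weight on the MSE distillation term
    cosine_coeff=0.1,                 # weight on the cosine-embedding distillation term
    distill_start_steps=10000,        # distillation terms are zeroed before this step
    teacher_checkpoint="/path/to/teacher_checkpoint")  # assumed path
# 0.4 + 0.3 + 0.2 + 0.1 == 1.0, so the coefficient check in __init__ does not raise.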
@@ -1751,15 +1763,40 @@ def call_simple(self,
weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
soft_loss = (mtf.reduce_sum(soft_loss * weights) /
self.student.loss_denominator(targets, num_microbatches))
+ if self.kl_coeff > 0.:
+   student_pred = mtf.softmax(student_logits / self.temperature,
+                              output_vocab_dim)
+   kl_loss = mtf.layers.kl_divergence(
+       mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+       weights=weights)
+ else:
+   kl_loss = 0.
+ if self.cosine_coeff > 0.:
+   cosine_loss = mtf.layers.cosine_embedding_distill(
+       mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+       weights=weights)
+ else:
+   cosine_loss = 0.
+ if self.mse_coeff > 0.:
+   mse_loss = mtf.layers.kl_divergence(
+       mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+       weights=weights)
+ else:
+   mse_loss = 0.
global_step = tf.train.get_or_create_global_step()
- current_fraction_soft = tf.cast(
+ distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                          self.mse_coeff + self.cosine_coeff)
+ current_distill_fraction = tf.cast(
tf.cond(
-     tf.math.greater(global_step, self.distill_start_step),
-     lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+     tf.math.greater(global_step, self.distill_start_steps),
+     lambda: distill_loss_fraction, lambda: tf.constant(0.0)),
dtype=tf.bfloat16)

- loss = (1.0 - current_fraction_soft) * hard_loss \
-     + self.temperature**2 * current_fraction_soft * soft_loss
+ loss = (1.0 - current_distill_fraction) * hard_loss \
+     + current_distill_fraction * (
+         self.temperature**2 * soft_loss * self.fraction_soft +
+         self.kl_coeff * kl_loss + self.mse_coeff * mse_loss +
+         self.cosine_coeff * cosine_loss)

return student_logits, loss
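To make the new loss combination concrete, here is a small plain-Python sketch of the arithmetic once global_step has passed distill_start_steps (before that, current_distill_fraction is 0 and only the hard one-hot loss remains). All numeric values are made up for illustration; the individual loss terms would come from the mtf computations above.

temperature = 2.0
fraction_soft, kl_coeff, mse_coeff, cosine_coeff = 0.4, 0.3, 0.2, 0.1
hard_loss, soft_loss, kl_loss, mse_loss, cosine_loss = 2.5, 1.8, 0.6, 0.9, 0.2

# Total fraction of the loss assigned to the distillation objectives.
distill_fraction = fraction_soft + kl_coeff + mse_coeff + cosine_coeff  # 1.0

loss = (1.0 - distill_fraction) * hard_loss + distill_fraction * (
    temperature**2 * soft_loss * fraction_soft   # 4 * 1.8 * 0.4 = 2.88
    + kl_coeff * kl_loss                         # 0.3 * 0.6 = 0.18
    + mse_coeff * mse_loss                       # 0.2 * 0.9 = 0.18
    + cosine_coeff * cosine_loss)                # 0.1 * 0.2 = 0.02
print(loss)  # ≈ 3.26; with distill_fraction == 1.0 the one-hot hard loss drops out.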

