Merge pull request huggingface#1222 from Leoooo333/master
Fix mixup/one_hot device problem
rwightman authored May 10, 2023
2 parents e0ec0f7 + fd592ec commit 8ce9a2c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions timm/data/mixup.py
@@ -14,16 +14,16 @@
 import torch
 
 
-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+def one_hot(x, num_classes, on_value=1., off_value=0.):
     x = x.long().view(-1, 1)
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+    return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
 
 
-def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
     off_value = smoothing / num_classes
     on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
     return y1 * lam + y2 * (1. - lam)
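
The hunk above is the core of the fix: one_hot previously allocated its output on a hard-coded device='cuda', which broke CPU-only runs and labels living on any other device; it now infers the device from the input tensor itself. A quick check of the new behavior, as a minimal sketch assuming PyTorch and the module layout shown in this diff:

import torch
from timm.data.mixup import one_hot, mixup_target  # as patched above

labels = torch.tensor([0, 2, 1])                   # a CPU label tensor
oh = one_hot(labels, num_classes=5, on_value=0.9, off_value=0.025)
assert oh.device == labels.device                  # one-hot follows the labels' device

soft = mixup_target(labels, num_classes=5, lam=0.7, smoothing=0.1)
assert soft.device == labels.device                # same for the mixed soft targets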


@@ -214,7 +214,7 @@ def __call__(self, x, target):
             lam = self._mix_pair(x)
         else:
             lam = self._mix_batch(x)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target


@@ -310,7 +310,7 @@ def __call__(self, batch, _=None):
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         target = target[:batch_size]
         return output, target
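
Taken together, the two call sites patched above show why the inference matters: Mixup.__call__ usually sees tensors already moved to the training device, while FastCollateMixup builds its targets inside a DataLoader worker on the CPU; with the device inferred from the labels, both paths now produce targets on the right device without the caller threading a device argument through. An end-to-end sketch (the constructor arguments here are assumed typical and are not taken from this diff):

import torch
from timm.data.mixup import Mixup  # the class patched in the second hunk

mixup_fn = Mixup(mixup_alpha=0.8, label_smoothing=0.1, num_classes=10)
x = torch.randn(8, 3, 224, 224)    # Mixup requires an even batch size
y = torch.randint(0, 10, (8,))
x_mixed, y_soft = mixup_fn(x, y)
assert y_soft.shape == (8, 10)     # soft targets, one row per sample
assert y_soft.device == y.device   # targets follow the labels' device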
