Resolve the one-hot target device from the label tensor in timm/data/mixup.py.

one_hot() and mixup_target() lose their `device` parameter (default 'cuda');
torch.full() now allocates on x.device, so the soft targets are built on
whatever device the label tensor already lives on. Both __call__ sites stop
passing a device.
NOTE(review): any caller outside this patch that still passes `device` will
raise TypeError; the first __call__ used to pass x.device (the batch's device),
so confirm targets and batch share a device there.

diff --git a/timm/data/mixup.py b/timm/data/mixup.py
index c8789a0c35..be0bae36c8 100644
--- a/timm/data/mixup.py
+++ b/timm/data/mixup.py
@@ -14,16 +14,16 @@
 import torch
 
 
-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+def one_hot(x, num_classes, on_value=1., off_value=0.):
     x = x.long().view(-1, 1)
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+    return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
 
 
-def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
     off_value = smoothing / num_classes
     on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
     return y1 * lam + y2 * (1. - lam)
 
 
@@ -214,7 +214,7 @@ def __call__(self, x, target):
             lam = self._mix_pair(x)
         else:
             lam = self._mix_batch(x)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target
 
 
@@ -310,7 +310,7 @@ def __call__(self, batch, _=None):
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         target = target[:batch_size]
         return output, target
 