-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgrl.py
88 lines (81 loc) · 3.03 KB
/
grl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from torch.autograd import Function
import glob
import random
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import logging
import parser
# Command-line arguments, parsed once at import time: importing this module
# reads sys.argv as a side effect. GrlDataset.__getitem__ reads `args.resize`.
# NOTE(review): `parser` is presumably the project's own parser.py; the stdlib
# `parser` module (removed in Python 3.10) has no `parse_arguments`.
args = parser.parse_arguments()
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass, negated (optionally scaled) gradient in the backward pass.

    Forward returns a copy of the input unchanged. Backward returns
    ``-beta * grad_output`` for the input, so whatever sits upstream of this
    function receives the reversed gradient. ``beta`` defaults to 1.0, which is
    plain negation.
    """

    @staticmethod
    def forward(ctx, x, beta=1.0):
        """Return a copy of ``x`` and store ``beta`` for the backward pass.

        Args:
            ctx: autograd context object.
            x: input tensor (any shape).
            beta: scale applied to the reversed gradient; 1.0 means pure negation.
        """
        ctx.beta = beta
        # clone() so the output is a new tensor rather than the input object itself.
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        """Return ``-beta * grads`` for ``x`` and ``None`` for ``beta``."""
        # Exactly one gradient per forward input: x gets the reversed gradient,
        # the (non-tensor) beta gets None.
        return grads.neg() * ctx.beta, None
class GradientReversal(nn.Module):
    """Average-pool a feature map to 1x1, flatten it, then apply gradient reversal.

    For a 4-D ``(N, C, H, W)`` input the output has shape ``(N, C)``. Forward
    values are the spatially averaged features. Gradients flowing back through
    this module are negated by ``GradientReversalFunction``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Collapse the two trailing (spatial) dimensions to a single value each.
        pooled = nn.functional.adaptive_avg_pool2d(x, output_size=(1, 1))
        # Keep the leading dimension and merge everything else into one axis.
        batch_size = pooled.shape[0]
        flat = pooled.view(batch_size, -1)
        return GradientReversalFunction.apply(flat)
def get_discriminator(input_dim, num_classes=2, hidden_dims=(50, 20)):
    """Build a discriminator head preceded by a gradient-reversal layer.

    The network is ``GradientReversal`` (1x1 average pool + flatten + gradient
    reversal), then ``Linear -> ReLU`` for each width in ``hidden_dims``, then
    a final ``Linear`` with ``num_classes`` outputs and no activation. With the
    default ``hidden_dims`` the layout is
    ``Linear(input_dim, 50) -> ReLU -> Linear(50, 20) -> ReLU -> Linear(20, num_classes)``.

    Args:
        input_dim: number of features entering the first Linear layer. For a
            4-D ``(N, C, H, W)`` feature map this is ``C``, since the
            reversal layer pools and flattens to ``(N, C)``.
        num_classes: number of output scores.
        hidden_dims: widths of the hidden layers, in order (may be empty).

    Returns:
        An ``nn.Sequential`` module.
    """
    layers = [GradientReversal()]
    in_features = input_dim
    for width in hidden_dims:
        layers.append(nn.Linear(in_features, width))
        layers.append(nn.ReLU())
        in_features = width
    layers.append(nn.Linear(in_features, num_classes))
    # Unpacking a flat list keeps the integer child indices (and state_dict keys)
    # identical to writing the layers out positionally.
    return nn.Sequential(*layers)
# Image pipeline used by GrlDataset:
#   1. random horizontal flip and colour jitter, applied to the PIL image;
#   2. conversion to a tensor;
#   3. per-channel normalization with the standard ImageNet mean/std values.
grl_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.7,
            contrast=0.7,
            saturation=0.7,
            hue=0.1,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class GrlDataset(torch.utils.data.Dataset):
    """Images labelled by the index of the dataset folder they were drawn from.

    Item ``index`` belongs to class ``index % N``, where N is the number of
    folders. The image itself is picked at random from that class, so every
    class gets the same share of indices regardless of how many images it
    contains.
    """

    def __init__(self, dataset_root, datasets_paths, length=1000000):
        """Collect the ``.jpg`` files of every dataset folder.

        Args:
            dataset_root: directory that contains the dataset folders.
            datasets_paths: list of N folder names, relative to ``dataset_root``;
                folder ``i`` becomes class ``i``.
            length: value returned by ``__len__`` (default 1000000). Because
                images are sampled at random, this only sets how many items a
                full pass over the dataset yields.

        Raises:
            ValueError: if ``datasets_paths`` is empty.
            RuntimeError: if a folder contains no ``.jpg`` files.
        """
        super().__init__()
        self.num_classes = len(datasets_paths)
        if self.num_classes == 0:
            # __getitem__ computes index % num_classes; fail here instead of
            # with a ZeroDivisionError on the first item.
            raise ValueError("GrlDataset needs at least one dataset path")
        logging.info(f"GrlDataset has {self.num_classes} classes")
        self.images_paths = []
        for dataset_path in datasets_paths:
            # Recursive search; only files ending in lowercase ".jpg" are matched.
            class_images = sorted(
                glob.glob(f"{dataset_root}/{dataset_path}/**/*.jpg", recursive=True)
            )
            logging.info(f" Class {dataset_path} has {len(class_images)} images")
            if not class_images:
                raise RuntimeError(f"Class {dataset_path} has 0 images, that's a problem!!!")
            self.images_paths.append(class_images)
        self.transform = grl_transform
        self.length = length

    def __getitem__(self, index):
        """Return ``(image_tensor, class_index)`` for a random image of class ``index % N``.

        ``index`` only selects the class. The image within the class comes from
        ``random.choice``, so repeated calls with the same index may differ.
        """
        num_class = index % self.num_classes
        image_path = random.choice(self.images_paths[num_class])
        # convert() loads the pixels, so the file can be closed before transforming.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        tensor = self.transform(rgb_image)
        # Final resize to the size given on the command line (module-level `args`).
        tensor = transforms.functional.resize(tensor, args.resize)
        return tensor, num_class

    def __len__(self):
        """Return the configured ``length``, not the number of images on disk."""
        return self.length