-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathaugment.py
117 lines (94 loc) · 3.18 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
import numpy as np
from torchvision import datasets, transforms
import random
from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF
class GaussianBlur(object):
    """
    Randomly apply a Gaussian blur to a PIL image.

    Each call blurs the image when a fresh ``random.random()`` draw is
    less than or equal to ``p``; otherwise the image is returned untouched.
    The blur radius is sampled uniformly between ``radius_min`` and
    ``radius_max`` on every application.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if random.random() <= self.prob:
            # Draw a new radius per image so the blur strength varies.
            radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img
class Solarization(object):
    """
    Randomly solarize a PIL image.

    With probability ``p`` the image is passed through
    ``ImageOps.solarize``; otherwise it is returned unchanged.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        apply_it = random.random() < self.p
        return ImageOps.solarize(img) if apply_it else img
class gray_scale(object):
    """
    Randomly convert the image to grayscale.

    With probability ``p`` the image is passed through
    ``transforms.Grayscale(3)`` (three output channels, so the result
    keeps the same channel count as an RGB input); otherwise it is
    returned unchanged.
    """

    def __init__(self, p=0.2):
        self.p = p
        # Built once and reused for every call.
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        return img
class horizontal_flip(object):
    """
    Randomly flip the image horizontally.

    With probability ``p`` the image is passed through a
    ``transforms.RandomHorizontalFlip(p=1.0)`` (i.e. an unconditional
    flip); otherwise it is returned unchanged.

    ``activate_pred`` is accepted for call-site compatibility but is not
    used by this class.
    """

    def __init__(self, p=0.2,activate_pred=False):
        self.p = p
        # p=1.0 makes the wrapped transform always flip; the randomness
        # is handled by this class in __call__.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        return img
def aug_generator(args = None):
    """
    Build the training augmentation pipeline.

    Reads the following attributes from ``args``:

    * ``input_size`` - target size passed to the resize/crop transforms.
    * ``src`` - when truthy, use Resize + padded RandomCrop instead of
      RandomResizedCropAndInterpolation.
    * ``disable_gray_solar_and_blur`` - when truthy, skip the random
      choice between grayscale, solarization and Gaussian blur.
    * ``color_jitter`` - when not ``None`` and not ``0``, append a
      ColorJitter using this value for brightness, contrast and saturation.

    Returns a ``transforms.Compose`` that ends with ``ToTensor`` and
    ``Normalize``.
    """
    size = args.input_size

    # Geometric stage: either a simple resize + padded crop, or a random
    # resized crop; both are followed by a random horizontal flip.
    if args.src:
        pipeline = [
            # interpolation=3 is Pillow's integer constant for bicubic.
            transforms.Resize(size, interpolation=3),
            transforms.RandomCrop(size, padding=4, padding_mode='reflect'),
        ]
    else:
        pipeline = [
            RandomResizedCropAndInterpolation(
                size, scale=(0.08, 1.0), interpolation='bicubic'),
        ]
    pipeline.append(transforms.RandomHorizontalFlip())

    # Photometric stage: exactly one of the three ops is applied per image.
    if not args.disable_gray_solar_and_blur:
        one_of_three = transforms.RandomChoice([
            gray_scale(p=1.0),
            Solarization(p=1.0),
            GaussianBlur(p=1.0),
        ])
        pipeline.append(one_of_three)

    jitter = args.color_jitter
    if jitter is not None and jitter != 0:
        pipeline.append(transforms.ColorJitter(jitter, jitter, jitter))

    # Tensor conversion and per-channel normalisation come last.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(
            mean=torch.tensor(channel_mean),
            std=torch.tensor(channel_std),
        )
    )
    return transforms.Compose(pipeline)