-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_utils.py
120 lines (94 loc) · 3.42 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
import os
def loadConfig(fn):
    """Load a JSON configuration file and return the decoded object.

    Args:
        fn: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file content (a ``dict``
        for a top-level JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON interchange is UTF-8; do not depend on the platform's default
    # text encoding, which differs between systems.
    with open(fn, encoding="utf-8") as f:
        return json.load(f)
def showImgs(img_tensor):
    """Display the first image of a batch and return it as a PIL image.

    Only element ``0`` of ``img_tensor`` is used; any further images in the
    batch are ignored.

    Args:
        img_tensor: Indexable batch of image tensors; the first element must
            be something ``torchvision.transforms.ToPILImage`` accepts
            (presumably a ``(C, H, W)`` tensor -- confirm against callers).

    Returns:
        The PIL image built from the first batch element.
    """
    to_pil = transforms.ToPILImage()
    pil_img = to_pil(img_tensor[0])
    # The converted image used to be discarded, so nothing was ever shown.
    # Hand it to PIL's default viewer and also return it for callers that
    # prefer to render or save it themselves.
    pil_img.show()
    return pil_img
def tensorToCV_RGB(imgs_tensor):
    """Convert a batch of image tensors into a list of NumPy image arrays.

    Each element of the batch is passed through
    ``torchvision.transforms.ToPILImage`` and the resulting PIL image is
    wrapped with ``np.array``.

    Args:
        imgs_tensor: Iterable batch tensor; it is moved to the CPU before
            conversion. Presumably ``(N, C, H, W)`` with RGB channel order,
            as the function name suggests -- confirm against callers.

    Returns:
        A list with one NumPy array per batch element, in batch order.
    """
    to_pil = transforms.ToPILImage()
    # Work on a host-side copy of the batch
    host_batch = imgs_tensor.cpu()
    # One array per image, preserving batch order
    return [np.array(to_pil(single_img)) for single_img in host_batch]
def rgbToColorMask(rgb_imgs, config):
    """Build normalized HSV color-range masks for a sequence of RGB images.

    Every image is converted from RGB to HSV and thresholded with
    ``cv2.inRange`` against the bounds stored in ``config``. Each mask gets
    a leading axis of size one, the masks are collected into a single
    ``float32`` array, and the values are divided by 255.

    Args:
        rgb_imgs: Iterable of images accepted by
            ``cv2.cvtColor(..., cv2.COLOR_RGB2HSV)``.
        config: Mapping providing ``"color_lower_range"`` and
            ``"color_upper_range"`` (HSV lower / upper bounds).

    Returns:
        ``float32`` NumPy array holding one mask per input image.
    """
    # HSV bounds come from the configuration
    hsv_lo = np.array(config["color_lower_range"])
    hsv_hi = np.array(config["color_upper_range"])

    def _mask_for(rgb):
        # Threshold in HSV space, then add a leading channel axis
        in_range = cv2.inRange(cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV), hsv_lo, hsv_hi)
        return np.expand_dims(in_range, axis=0)

    stacked = np.array([_mask_for(rgb) for rgb in rgb_imgs])
    # Scale the 8-bit mask values by 1/255
    return np.divide(stacked.astype(np.float32), 255.)
def plotHistory(history, savepath):
    """Plot the recorded loss curves and save the figure as a JPEG.

    Three curves are drawn against the epoch index: ``history["loss"]``
    ("Weighted Loss"), ``history["raw_loss"]`` ("Raw Loss") and
    ``history["mask_loss"]`` ("Mask Loss"). The image is written to
    ``<savepath>/Graph_<epochs>.jpg`` where ``epochs`` is
    ``len(history["raw_loss"])``.

    Args:
        history: Mapping with the keys ``"raw_loss"``, ``"loss"`` and
            ``"mask_loss"``, each a per-epoch sequence of loss values.
        savepath: Directory the image is written into.

    Raises:
        KeyError: If one of the required keys is missing from ``history``.
    """
    # Read the inputs before creating the figure so a missing key cannot
    # leave an orphaned figure behind.
    raw_loss = history["raw_loss"]
    loss = history["loss"]
    mask_loss = history["mask_loss"]

    epochs = len(raw_loss)
    # Shared x-axis for all three curves
    epoch_axis = np.arange(epochs, dtype=np.int32)

    fig = plt.figure(figsize=(16, 10))
    try:
        plt.plot(epoch_axis, loss)
        plt.plot(epoch_axis, raw_loss)
        plt.plot(epoch_axis, mask_loss)
        plt.legend(["Weighted Loss", "Raw Loss", "Mask Loss"])
        plt.xlabel("Epochs")
        plt.ylabel("Loss")

        filename = f"Graph_{epochs}.jpg"
        plt.savefig(os.path.join(savepath, filename))
    finally:
        # Always release the figure, even when plotting or saving fails;
        # otherwise figures pile up across repeated calls.
        plt.close(fig)
class ImageFolderMask(Dataset):
    """Dataset of ``(raw image, mask)`` tensor pairs found under a root folder.

    The root directory is scanned one level deep: every sub-folder is
    listed, files whose name starts with ``prefix`` are collected as masks
    and all other files as raw images. Both path lists are sorted
    independently and paired by position, so a raw image and its mask must
    end up at the same index of their respective sorted lists
    (NOTE(review): this presumably relies on mask names being the prefix
    followed by the raw file name -- confirm against the data layout).
    """

    def __init__(self, root_path: str, prefix: str = "mask"):
        """Collect raw and mask file paths below ``root_path``.

        Args:
            root_path: Directory whose sub-folders contain the image files.
            prefix: File-name prefix that identifies mask files.

        Raises:
            AssertionError: If a sub-folder does not contain the same number
                of raw and mask files.
        """
        super().__init__()
        self.raw_files = []
        self.mask_files = []
        self.prefix = prefix

        for folder in os.listdir(root_path):
            folder_path = os.path.join(root_path, folder)
            # Stray files next to the data folders (e.g. ".DS_Store") are
            # not data folders; listing them would raise.
            if not os.path.isdir(folder_path):
                continue

            folder_raw = []
            folder_masks = []
            for fn in os.listdir(folder_path):
                full_path = os.path.join(folder_path, fn)
                # Only regular files can be opened as images later on
                if not os.path.isfile(full_path):
                    continue
                if fn.startswith(prefix):
                    folder_masks.append(full_path)
                else:
                    folder_raw.append(full_path)

            # Checked per folder: a surplus in one folder and a deficit in
            # another would cancel out in a global count while shifting the
            # positional pairing. Raised explicitly (instead of via
            # ``assert``) so the check survives ``python -O`` while keeping
            # the exception type unchanged.
            if len(folder_raw) != len(folder_masks):
                raise AssertionError(
                    f"{folder_path}: found {len(folder_raw)} raw file(s) "
                    f"but {len(folder_masks)} mask file(s)"
                )

            self.raw_files.extend(folder_raw)
            self.mask_files.extend(folder_masks)

        # Positional pairing relies on both lists being sorted the same way
        self.raw_files = sorted(self.raw_files)
        self.mask_files = sorted(self.mask_files)

    def __len__(self) -> int:
        """Return the number of raw images (equal to the number of masks)."""
        return len(self.raw_files)

    def __getitem__(self, index: int):
        """Load the raw image and mask at ``index`` as float tensors.

        Both images are divided by 255.0 and rearranged from
        ``(H, W, C)`` to ``(C, H, W)``. The mask gets a trailing channel
        axis first, so the raw image must decode to a 3-D array and the mask
        to a 2-D array for the transposes to succeed.

        Args:
            index: Position in the sorted file lists.

        Returns:
            Tuple ``(raw_tensor, mask_tensor)`` of ``torch.float32`` tensors.
        """
        raw_path = self.raw_files[index]
        mask_path = self.mask_files[index]

        # Context managers release the underlying file handles right after
        # the pixel data has been copied into arrays.
        with Image.open(raw_path) as raw_pil:
            raw_img = np.array(raw_pil)
        with Image.open(mask_path) as mask_pil:
            mask_img = np.expand_dims(np.array(mask_pil), axis=-1)

        # Normalize
        raw_img = np.divide(raw_img, 255.0)
        mask_img = np.divide(mask_img, 255.0)

        # H, W, C -> C, H, W
        raw_img = np.transpose(raw_img, (2, 0, 1))
        mask_img = np.transpose(mask_img, (2, 0, 1))

        # Converting to torch tensors
        raw_tensor = torch.from_numpy(raw_img).float()
        mask_tensor = torch.from_numpy(mask_img).float()
        return raw_tensor, mask_tensor