-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
69 lines (43 loc) · 1.51 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image

from config import *
from models import VGG, U2NET
def open_image(path):
    """Return a lazily-loaded PIL image for the file at *path*."""
    img = Image.open(path)
    return img
def load_image(img):
    """Read an image from *img* (a file path), preprocess it with the
    project's ``basic_transform`` pipeline, and return it as a single-item
    batch tensor moved to ``DEVICE``.

    NOTE(review): ``.convert()`` is called without an explicit mode, so PIL
    picks one — presumably the transform expects RGB; confirm and pass
    ``'RGB'`` explicitly if so.
    """
    pil_img = open_image(img).convert()
    tensor = basic_transform(image=np.array(pil_img))['image']
    return tensor.unsqueeze(0).to(DEVICE)
def load_essential_images(original_img, style_img):
    """Load the content and style images and create the optimizable image.

    The generated image starts as a gradient-tracking clone of the content
    image, ready to be optimized by a style-transfer loop.

    Returns:
        (content_tensor, style_tensor, generated_tensor)
    """
    content = load_image(original_img)
    style = load_image(style_img)
    generated = content.clone().requires_grad_(True)
    return content, style, generated
def gram_matrix(feature_map):
    """Return the Gram matrix ``F @ F.T`` of a 2-D feature-map tensor."""
    return feature_map @ feature_map.t()
def min_max_normalization(d):
    """Linearly rescale tensor *d* into the range [0, 1].

    Fix: the original divided by ``max - min`` unconditionally, producing
    NaNs (0/0) for a constant tensor. A constant input now maps to an
    all-zero tensor of the same shape.
    """
    lo, hi = torch.min(d), torch.max(d)
    span = hi - lo
    if span == 0:
        # Degenerate case: every element equal — nothing to scale.
        return torch.zeros_like(d)
    return (d - lo) / span
def checkpoint_exists(filename):
    """Return True when *filename* exists inside ``CHECKPOINT_DIR``.

    Fix: the original scanned the whole directory with ``os.listdir`` on
    every call (O(n) per lookup) and raised ``FileNotFoundError`` when
    ``CHECKPOINT_DIR`` itself did not exist yet; ``os.path.exists`` on the
    joined path is a single stat and simply returns False in that case.
    """
    return os.path.exists(os.path.join(CHECKPOINT_DIR, filename))
def load_vgg_model():
    """Instantiate the project's VGG feature extractor on DEVICE in eval mode."""
    model = VGG().to(DEVICE)
    model.eval()
    return model
def load_u2net_model(in_ch=3, out_ch=4, checkpoint=U2NET_CLOTHES_CHECKPOINT_FILE, ordered_dict=True):
    """Build a ``U2NET(in_ch, out_ch)`` on DEVICE and load weights if available.

    Args:
        in_ch: input channel count forwarded to the U2NET constructor.
        out_ch: output channel count forwarded to the U2NET constructor.
        checkpoint: checkpoint *filename*, looked up inside CHECKPOINT_DIR.
        ordered_dict: when True, strip the ``module.`` key prefix that
            ``torch.nn.DataParallel`` saves, before loading the state dict.

    Returns:
        The model in eval mode (with random init weights when no checkpoint
        file is found).
    """
    model = U2NET(in_ch, out_ch).to(DEVICE)
    if checkpoint_exists(checkpoint):
        # Fix: keep the filename in `checkpoint` and the loaded weights in a
        # separate name (the original shadowed the parameter).
        state = torch.load(
            os.path.join(CHECKPOINT_DIR, checkpoint),
            map_location=DEVICE
        )
        if ordered_dict:
            # Fix: strip the leading "module." only when actually present;
            # the original blindly dropped the first 7 characters of every
            # key, corrupting checkpoints saved without DataParallel.
            remapped = OrderedDict()
            for key, value in state.items():
                name = key[7:] if key.startswith('module.') else key
                remapped[name] = value
            model.load_state_dict(remapped)
        else:
            model.load_state_dict(state)
    model.eval()
    return model