# Augmenting CLIP, 2021, by Peter Baylies (@pbaylies)
# Simple MLPs trained against precomputed LAION-400M CLIP embeddings: one maps
# image embeddings to text embeddings (i2t), the other text to image (t2i).
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import shutil

def save_ckp(state, is_best, prefix=''):
    """Save a checkpoint; if it is the best so far, also copy it to a 'best model' file."""
    f_path = prefix + '_checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        best_fpath = prefix + '_best_model.pt'
        shutil.copyfile(f_path, best_fpath)

in_size = 512   # CLIP embedding dimensionality (512 for ViT-B/32)
out_size = 512
device = torch.device('cuda')

# Two small MLP heads: model1 learns image -> text (i2t),
# model2 learns text -> image (t2i).
model1 = nn.Sequential(
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, out_size, True),
    nn.Tanh(),
).to(device)
model2 = nn.Sequential(
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, in_size, True),
    nn.ELU(),
    nn.Linear(in_size, out_size, True),
    nn.Tanh(),
).to(device)

# To resume training from a previously saved checkpoint:
# model1 = torch.load('i2t_checkpoint.pt')

batch_size = 256
lr = 1e-3
optim1 = torch.optim.AdamW(model1.parameters(), lr)
optim2 = torch.optim.AdamW(model2.parameters(), lr)
smoothl1 = nn.SmoothL1Loss()
count = 0                 # batches seen so far
min_loss1 = float('inf')  # best (lowest) i2t loss seen so far
min_loss2 = float('inf')  # best (lowest) t2i loss seen so far
steps = 1000              # checkpoint every `steps` batches

# The LAION-400M embeddings are distributed as 410 numpy shards.
for d in range(410):
    inputs = torch.from_numpy(np.load('images/img_emb_%d.npy' % d)).float().to(device)
    outputs = torch.from_numpy(np.load('texts/text_emb_%d.npy' % d)).float().to(device)
    # Unit-normalize, matching how CLIP embeddings are compared.
    inputs /= inputs.norm(dim=-1, keepdim=True)
    outputs /= outputs.norm(dim=-1, keepdim=True)
    dataset = TensorDataset(inputs, outputs)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    for inp, target in dataloader:
        count += 1
        # Image -> text direction.
        optim1.zero_grad()
        output = model1(inp)
        loss1 = 1000 * smoothl1(output, target)  # scale up for readable magnitudes
        loss1.backward()
        optim1.step()
        if count % steps == 0:
            print(count)
            is_best = loss1.item() < min_loss1
            save_ckp(model1, is_best, prefix='i2t')
            if is_best:
                min_loss1 = loss1.item()
                print('i2t: ' + str(min_loss1))
        # Swap the pair to train the text -> image direction.
        inp, target = target, inp
        optim2.zero_grad()
        output = model2(inp)
        loss2 = 1000 * smoothl1(output, target)
        loss2.backward()
        optim2.step()
        if count % steps == 0:
            is_best = loss2.item() < min_loss2
            save_ckp(model2, is_best, prefix='t2i')
            if is_best:
                min_loss2 = loss2.item()
                print('t2i: ' + str(min_loss2))

# If training ended between checkpoints, save one final checkpoint.
if count % steps != 0:
    print(count)
    is_best = loss1.item() < min_loss1
    save_ckp(model1, is_best, prefix='i2t')
    if is_best:
        min_loss1 = loss1.item()
        print('i2t: ' + str(min_loss1))
    is_best = loss2.item() < min_loss2
    save_ckp(model2, is_best, prefix='t2i')
    if is_best:
        min_loss2 = loss2.item()
        print('t2i: ' + str(min_loss2))
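
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original script): after training,
# the best image -> text head can be restored and applied to a unit-normalized
# CLIP image embedding. save_ckp() above stores the whole nn.Sequential via
# torch.save(), so torch.load() returns the module directly (newer PyTorch
# versions may require weights_only=False for full-model loads). The random
# vector below is a stand-in for a real CLIP image embedding.
# ---------------------------------------------------------------------------
i2t = torch.load('i2t_best_model.pt').to(device)
i2t.eval()
with torch.no_grad():
    img_emb = torch.randn(1, in_size, device=device)  # placeholder CLIP image embedding
    img_emb /= img_emb.norm(dim=-1, keepdim=True)     # normalize, matching training
    pred_text_emb = i2t(img_emb)                      # predicted text-space embedding
    print(pred_text_emb.shape)                        # torch.Size([1, 512])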