Commit: blblbl
filippo-merlo committed Feb 6, 2024
1 parent 5abce49 · commit db8ab5c
Showing 2 changed files with 151 additions and 35 deletions.
config.py: 6 changes (3 additions & 3 deletions)
@@ -73,13 +73,13 @@
 
 # paths and filenames
 bn_n_train = "bn_n_train.txt" # 23 attrs, -9 combs
-bsn_novel_train_1 = "bsn_novel_train_1.txt" #
+bsn_novel_train_1 = "bsn_novel_train_1.txt" # 20 attrs, -9 combs
 bsn_novel_train_2 = "bsn_novel_train_2.txt" #
 bsn_novel_train_2_nw = "bsn_novel_train_2_nw.txt" #
 bsn_novel_train_2_old = "bsn_novel_train_2_old.txt" #
 
 bn_n_test = "bn_n_test.txt" # 23 attrs, all combs
-bsn_novel_test_1 = "bsn_novel_test_1.txt"
+bsn_novel_test_1 = "bsn_novel_test_1.txt"
 bsn_novel_test_2_nw = "bsn_novel_test_2_nw.txt"
 bsn_novel_test_2_old = "bsn_novel_test_2_old.txt"
@@ -92,7 +92,7 @@
 # train parameters
 resize = 224
 lr = 1e-3
-epochs = 5
+epochs = 1#5
 
 sim_batch = 128
 gen_batch = 128
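
Two notes on the config.py changes above. The repeated bsn_novel_test_1 line is presumably a whitespace-only change, since the visible text of the removed and added lines is identical. More importantly, the commit hard-codes epochs = 1 for a quick debugging run, keeping the previous value only in the trailing comment (1#5). A minimal sketch of an alternative that keeps both values explicit; the DEBUG flag is hypothetical and not part of the original config.py:

    # hypothetical debug toggle (DEBUG is not in the original config)
    DEBUG = True

    epochs = 1 if DEBUG else 5           # single-epoch smoke test vs. full training
    inner_iters = 1 if DEBUG else 200    # inner-loop iterations in my_train_clip_encoder
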
main.py: 180 changes (148 additions & 32 deletions)
@@ -33,34 +33,35 @@ def my_train_clip_encoder(dt, memory, attr, lesson):
     loss = 10
     ct = 0
     centroid_sim = torch.rand(1, latent_dim).to(device)
-    while loss > 0.008:
-        ct += 1
-        if ct > 5:
-            break
-        for i in range(200):
-            # Get Inputs: sim_batch, (sim_batch, 4, 128, 128)
-            base_name_sim, images_sim = dt.get_better_similar(attr, lesson)
-            images_sim = images_sim.to(device)
-
-            # run similar model
-            z_sim = model(clip_model, images_sim)
-            centroid_sim = centroid_sim.detach()
-            centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim)))
-
-            # Run Difference
-            base_name_dif, images_dif = dt.get_better_similar_not(attr, lesson)
-            images_dif = images_dif.to(device)
-
-            # run difference model
-            z_dif = model(clip_model, images_dif)
-            loss_dif = get_sim_not_loss(centroid_sim, z_dif)
-
-            # compute loss
-            loss = (loss_sim)**2 + (loss_dif-1)**2
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
+    #while loss > 0.008:
+    ##
+    ct += 1
+    #if ct > 5:
+    #    break
+    for i in range(1):#200):
+        # Get Inputs: sim_batch, (sim_batch, 4, 128, 128)
+        base_name_sim, images_sim = dt.get_better_similar(attr, lesson)
+        images_sim = images_sim.to(device)
+
+        # run similar model
+        z_sim = model(clip_model, images_sim)
+        centroid_sim = centroid_sim.detach()
+        centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim)))
+
+        # Run Difference
+        base_name_dif, images_dif = dt.get_better_similar_not(attr, lesson)
+        images_dif = images_dif.to(device)
+
+        # run difference model
+        z_dif = model(clip_model, images_dif)
+        loss_dif = get_sim_not_loss(centroid_sim, z_dif)
+
+        # compute loss
+        loss = (loss_sim)**2 + (loss_dif-1)**2
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    ##
     print('[', ct, ']', loss.detach().item(), loss_sim.detach().item(),
           loss_dif.detach().item())

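For context on the loop above: the objective loss = (loss_sim)**2 + (loss_dif-1)**2 is minimized when loss_sim reaches 0 (embeddings of similar images collapse onto the running centroid) and loss_dif reaches 1 (embeddings of dissimilar images sit far from that centroid), which is also why training used to stop once loss fell below 0.008. get_sim_loss and get_sim_not_loss are imported from elsewhere in the repository; the following is only a sketch of losses with that behavior, assuming cosine-similarity-based distances (the real definitions may differ):

    import torch
    import torch.nn.functional as F

    def get_sim_loss_sketch(z):
        # centroid of the stacked embeddings (batch plus the previous centroid)
        centroid = z.mean(dim=0, keepdim=True)
        # 0 when every embedding aligns with the centroid
        loss = (1 - F.cosine_similarity(z, centroid)).mean()
        return centroid, loss

    def get_sim_not_loss_sketch(centroid, z):
        # approaches 1 when z is orthogonal to (far from) the centroid
        return (1 - F.cosine_similarity(z, centroid)).mean()
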
@@ -149,6 +150,75 @@ def my_clip_evaluation(in_path, source, memory, in_base, types, dic, vocab):
           top3_shape/tot_num, top3/tot_num)
     return top3/tot_num
 
+def memory_evaluation(in_path, source, memory, in_base, types, dic, vocab):
+    with torch.no_grad():
+
+        # get dataset
+        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
+        dt = MyDataset(in_path, source, in_base, types, dic, vocab,
+                       clip_preprocessor=clip_preprocess)
+        data_loader = DataLoader(dt, batch_size=128, shuffle=True)
+
+        top3 = 0
+        top3_color = 0
+        top3_material = 0
+        top3_shape = 0
+        tot_num = 0
+
+        for base_is, images in data_loader:
+            # Prepare the inputs
+            images = images.to(device)
+            ans = []
+            batch_size_i = len(base_is)
+
+            # go through memory
+            for label in vocab:
+                if label not in memory.keys():
+                    ans.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1))
+                    continue
+
+                # load model
+                model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False)
+                model.load_state_dict(memory[label]['model'])
+                model.to(device)
+                model.eval()
+
+                # load centroid
+                centroid_i = memory[label]['centroid'].to(device)
+                centroid_i = centroid_i.repeat(batch_size_i, 1)
+
+                # compute stats
+                z = model(clip_model, images).squeeze(0)
+                disi = ((z - centroid_i)**2).mean(dim=1)
+                ans.append(disi.detach().to('cpu'))
+
+            # get top3 indices
+            ans = torch.stack(ans, dim=1)
+            values, indices = ans.topk(3, largest=False)
+            _, indices_lb = base_is.topk(3)
+            indices_lb, _ = torch.sort(indices_lb)
+
+            # calculate stats
+            tot_num += len(indices)
+            for bi in range(len(indices)):
+                ci = 0
+                mi = 0
+                si = 0
+                if indices_lb[bi][0] in indices[bi]:
+                    ci = 1
+                if indices_lb[bi][1] in indices[bi]:
+                    mi = 1
+                if indices_lb[bi][2] in indices[bi]:
+                    si = 1
+
+                top3_color += ci
+                top3_material += mi
+                top3_shape += si
+                if (ci == 1) and (mi == 1) and (si == 1):
+                    top3 += 1
+
+    return tot_num, top3/tot_num, top3_color/tot_num, top3_material/tot_num, top3_shape/tot_num
+
+
 def my_clip_train(in_path, out_path, model_name, source, in_base,
                   types, dic, vocab, pre_trained_model=None):
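
A note on the matching logic inside the new memory_evaluation: ans ends up as a (batch, len(vocab)) matrix of squared distances to each concept's centroid, topk(3, largest=False) selects the three nearest concept models, and base_is.topk(3) recovers the three ground-truth attribute indices from the multi-hot label, which after sorting are treated as color, material, and shape in vocabulary order. A toy check of that logic with hypothetical values:

    import torch

    # hypothetical single-image batch over a 5-concept vocabulary
    ans = torch.tensor([[0.40, 0.05, 0.90, 0.10, 0.30]])  # squared distances
    base_is = torch.tensor([[1, 1, 0, 1, 0]])             # multi-hot ground truth

    _, indices = ans.topk(3, largest=False)   # nearest concepts: [[1, 3, 4]]
    _, indices_lb = base_is.topk(3)           # ground-truth indices
    indices_lb, _ = torch.sort(indices_lb)    # -> [[0, 1, 3]]

    hits = [int(indices_lb[0][k] in indices[0]) for k in range(3)]
    print(hits)  # [0, 1, 1]: color missed, material and shape within the top 3
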
@@ -182,7 +252,7 @@ def my_clip_train(in_path, out_path, model_name, source, in_base,
 
         # evaluate
         top_nt = my_clip_evaluation(in_path, 'novel_test/', memory,
-                                    bsn_novel_test_1, ['rgba'], dic_train, vocab)
+                                    bn_n_test, ['rgba'], dic_train, vocab)
         if top_nt > best_nt:
             best_nt = top_nt
             print("++++++++++++++ BEST NT: " + str(best_nt))
@@ -196,11 +266,57 @@
                            help='Data input path', required=True)
     argparser.add_argument('--out_path', '-o',
                            help='Model memory output path', required=True)
     argparser.add_argument('--model_name', '-n', default='best_mem.pickle',
                            help='Best model memory to be saved file name', required=False)
     argparser.add_argument('--pre_train', '-p', default=None,
                            help='Pretrained model import name (saved in outpath)', required=False)
     args = argparser.parse_args()
 
-    my_clip_train(args.in_path, args.out_path, args.model_name,
+    evaluations = {
+        'mare_novel_comp': [],
+        'mare_var': []
+    }
+
+    # Train without 9 combs and 3 attrs
+    model_name = 'comparative_base_mem.pickle'
+    my_clip_train(args.in_path, args.out_path, model_name,
                   'novel_train/', bn_n_train, ['rgba'], dic_train, vocabs, args.pre_train)
+
+    # Multi-Attribute Recognition Evaluation
+    in_memory = os.path.join(args.out_path, model_name)
+    infile = open(in_memory, 'rb')
+    memory = pickle.load(infile)
+    infile.close()
+
+    # evaluate with novel compositions
+    evaluations['mare_novel_comp'].append(memory_evaluation(args.in_path, 'novel_test/', memory,
+                                          bn_n_test, ['rgba'], dic_train, vocabs))
+
+    # evaluate with variations (dic_test)
+    evaluations['mare_var'].append(memory_evaluation(args.in_path, 'novel_test/', memory,
+                                   bn_n_train, ['rgba'], dic_test, vocabs))
+
+    ## Train for new word acquisition
+    ## without
+    #model_name = 'nw_acquisition_old.pickle'
+    #my_clip_train(args.in_path, args.out_path, model_name,
+    #              'novel_train/', bsn_novel_train_1, ['rgba'], dic_train, vocabs, args.pre_train)
+    #
+    #in_memory = os.path.join(args.out_path, model_name)
+    #infile = open(in_memory, 'rb')
+    #memory = pickle.load(infile)
+    #infile.close()
+    #
+    #evaluations['mare_new_word_old_new_comp'].append(memory_evaluation(args.in_path, 'novel_test/', memory,
+    #    bsn_novel_test_1, ['rgba'], dic_train, vocabs))
+    #
+    #evaluations['mare_new_word_old_new_comp_var'].append(memory_evaluation(args.in_path, 'novel_test/', memory,
+    #    bsn_novel_test_1, ['rgba'], dic_test, vocabs))
+    #
+    ## with new word
+    #model_name = 'nw_acquisition_new.pickle'
+    #my_clip_train(args.in_path, args.out_path, model_name,
+    #              'novel_train/', bsn_novel_train_2_old, ['rgba'], dic_train, vocabs, pre_train = 'nw_acquisition_old.pickle')
+    #
+    #in_memory = os.path.join(args.out_path, model_name)
+    #infile = open(in_memory, 'rb')
+    #memory = pickle.load(infile)
+    #infile.close()
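
One style note on the loading pattern above: the memory pickle is opened, read, and closed by hand each time it is needed; a context manager is the more idiomatic equivalent (same behavior, sketched with the same names as in the diff):

    import os
    import pickle

    in_memory = os.path.join(args.out_path, model_name)
    with open(in_memory, 'rb') as infile:
        memory = pickle.load(infile)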
