diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..54e3650
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..dfe0770
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000..f6f8da6
Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ
diff --git a/__pycache__/dataset.cpython-310.pyc b/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000..809c365
Binary files /dev/null and b/__pycache__/dataset.cpython-310.pyc differ
diff --git a/__pycache__/util.cpython-310.pyc b/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000..b946e22
Binary files /dev/null and b/__pycache__/util.cpython-310.pyc differ
diff --git a/code/__pycache__/config.cpython-310.pyc b/code/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000..3922a78
Binary files /dev/null and b/code/__pycache__/config.cpython-310.pyc differ
diff --git a/code/__pycache__/dataset.cpython-310.pyc b/code/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000..752b0fa
Binary files /dev/null and b/code/__pycache__/dataset.cpython-310.pyc differ
diff --git a/code/__pycache__/models.cpython-310.pyc b/code/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000..70c553a
Binary files /dev/null and b/code/__pycache__/models.cpython-310.pyc differ
diff --git a/code/__pycache__/util.cpython-310.pyc b/code/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000..0ca94b9
Binary files /dev/null and b/code/__pycache__/util.cpython-310.pyc differ
diff --git a/code/best_mem.pickle b/code/best_mem.pickle
new file mode 100755
index 0000000..77d51ce
Binary files /dev/null and b/code/best_mem.pickle differ
diff --git a/code/config.py b/code/config.py
new file mode 100644
index 0000000..640a18d
--- /dev/null
+++ b/code/config.py
@@ -0,0 +1,144 @@
+#%%
+'''
+Learning Attributes:
+    - Color (6)
+    - Material (3)
+    - Shape (8)
+Additional Attributes:
+    - Color (2)
+    - Material (1)
+    - Shape (3)
+Flexibility:
+    - Camera angle (6)
+    - Lighting (3)
+Variability (Only in testing):
+    - Size (2) [Default: large]
+    - Stretch (3) [Default: normal]
+    - Color shade (2) [Default: base]
+
+Naming convention:
+[color]_[material]_[shape]_shade_[]_stretch_[]_scale_[]_brightness_[]_view_[]_[tyimg].png
+e.g.
+aqua_glass_cone_shade_base_stretch_normal_scale_large_brightness_bright_view_0_-2_3_rgba.png
+'''
+# Learning attributes:
+colors = ['brown', "green", "blue", "aqua", "purple", "red", "yellow", 'white']
+materials = ['rubber', 'metal', 'plastic', 'glass']
+shapes = ["cube", "cylinder", "sphere", "cone", "torus", "gear",
+          "torus_knot", "sponge", "spot", "teapot", "suzanne"]
+vocabs = colors+materials+shapes
+
+
+# Flexibility:
+views = ['0_3_2', '-2_-2_2', '-2_2_2', '1.5_-1.5_3', '1.5_1.5_3', '0_-2_3']
+brightness = ['dim', 'normal', 'bright']
+
+# Variability
+scale_train = ['large']
+stretch_train = ['normal']
+shade_train = ['base']
+
+scale_test = ['small', 'medium', 'large']
+stretch_test = ['normal', 'x', 'y', 'z']
+shade_test = ['base', 'light', 'dark']
+
+others = views + brightness + scale_test + stretch_test + shade_test
+
+# Types of images
+tyimgs = ['rgba', 'depth', 'normal', 'object_coordinates', 'segmentation']
+
+
+dic_train = {"color": colors,
+             "material": materials,
+             "shape": shapes,
+             "view": views,
+             'brightness': brightness,
+             "scale": scale_train,
+             'stretch': stretch_train,
+             'shade': shade_train
+             }
+dic_test = {"color": colors,
+            "material": materials,
+            "shape": shapes,
+            "view": views,
+            'brightness': brightness,
+            "scale": scale_test,
+            'stretch': stretch_test,
+            'shade': shade_test
+            }
+
+types_learning = ['color', 'material', 'shape']
+types_flebility = ['color', 'material', 'shape', 'brightness', 'view']
+types_variability = ['scale', 'stretch', 'shade']
+types_all = ['color', 'material', 'shape', 'brightness',
+             'view', 'shade', 'stretch', 'scale']
+
+# make dicts for logical training and testing
+relations = ['and', 'or', 'not']  # <--- new
+types_logical = []
+for i in types_learning:
+    for j in relations:
+        if j == 'not':
+            types_logical.append(j+' '+i)
+        else:
+            for h in types_learning:
+                if h+' '+j+' '+i not in types_logical:
+                    if j == 'and' and i == h:
+                        pass
+                    else:
+                        types_logical.append(i+' '+j+' '+h)
+types_logical_with_learning = types_learning + types_logical
+
+from itertools import product
+from pprint import pprint
+dic_train_logical = dic_train.copy()
+for rel in types_logical:
+    if rel.split(' ')[0] == 'not':
+        attr = rel.split(' ')[1]
+        dic_train_logical[rel] = [f'not {x}' for x in dic_train[attr]]
+    else:
+        attr1 = rel.split(' ')[0]
+        r = rel.split(' ')[1]
+        attr2 = rel.split(' ')[2]
+        dic_train_logical[rel] = [f'{x} {r} {y}' for x, y in product(dic_train[attr1], dic_train[attr2]) if x != y]
+
+all_vocabs = []
+for v in dic_train_logical.values():
+    for n in v:
+        if n not in others:
+            all_vocabs.append(n)
+
+#print(all_vocabs)
+#pprint(dic_train_logical)
+#print(types_logical_with_learning)
+
+# paths and filenames
+bn_n_train = "bn_n_train.txt"
+bsn_novel_train_1 = "bsn_novel_train_1.txt"
+bsn_novel_train_2 = "bsn_novel_train_2.txt"
+bsn_novel_train_2_nw = "bsn_novel_train_2_nw.txt"
+bsn_novel_train_2_old = "bsn_novel_train_2_old.txt"
+
+bn_n_test = "bn_n_test.txt"
+bsn_novel_test_1 = "bsn_novel_test_1.txt"
+bsn_novel_test_2_nw = "bsn_novel_test_2_nw.txt"
+bsn_novel_test_2_old = "bsn_novel_test_2_old.txt"
+
+bn_train = "bn_train.txt"
+bn_test = "bn_test.txt"
+bsn_test_1 = "bsn_test_1.txt"
+bsn_test_2_nw = "bsn_test_2_nw.txt"
+bsn_test_2_old = "bsn_test_2_old.txt"
+
+# train parameters
+resize = 224
+lr = 1e-3
+epochs = 50
+
+sim_batch = 132
+gen_batch = 132
+batch_size = 33
+
+# model architecture
+hidden_dim_clip = 128
+latent_dim = 16
diff --git a/code/dataset.py b/code/dataset.py
new file mode 100644
index 0000000..a939dd8
--- /dev/null
+++ b/code/dataset.py
@@ -0,0 +1,288 @@
+import os
+import torch
+import random
+from PIL import Image
+import torch.nn.functional as F
+import torchvision.transforms as TT
+from torch.utils.data.dataset import Dataset
+
+from config import *
+from util import *
+
+
+class MyDataset(Dataset):
+    def __init__(self, in_path, source, in_base, types,
+                 dic, vocab, clip_preprocessor=None):
+        self.dic = dic
+        self.source = source
+        self.types = types
+        self.in_path = in_path
+        self.totensor = TT.ToTensor()
+        self.resize = TT.Resize((resize, resize))
+        self.clip_preprocessor = clip_preprocessor
+
+        # convert vocab list to dic
+        self.vocab = vocab
+        self.vocab_nums = {xi: idx for idx, xi in enumerate(self.vocab)}
+
+        # Get list of test images
+        self.names_list = []
+        with open(os.path.join(self.in_path, 'names', in_base)) as f:
+            lines = f.readlines()
+            for line in lines:
+                self.names_list.append(line[:-1])
+
+        self.name_set = set(self.names_list)
+
+    def __len__(self):
+        return len(self.names_list)
+
+    # only for CLIP emb
+    def __getitem__(self, idx):
+        base_name = self.names_list[idx]
+        image = self.img_emb(base_name)
+
+        # get label indices
+        nm = pareFileNames(base_name)
+        num_labels = [self.vocab_nums[li] for li in [nm['color'],
+                      nm['material'], nm['shape']]]
+
+        # turn num_labels into one-hot
+        labels = torch.zeros(len(self.vocab))
+        for xi in num_labels:
+            labels[xi] = 1
+
+        return labels, image
+
+    def img_emb(self, base_name):
+        # get names
+        names = []
+        for tp in self.types:
+            names.append(os.path.join(self.in_path, self.source,
+                         base_name + '_' + tp + '.png'))
+
+        # if clip preprocess
+        if self.clip_preprocessor is not None:
+            images = self.clip_preprocessor(Image.open(names[0]))
+            return images
+
+        # preprocess images
+        images = []
+        for ni in range(len(names)):
+            input_image = Image.open(names[ni]).convert('RGB')
+            input_image = self.totensor(input_image)
+
+            if names[ni][-16:] == "segmentation.png":
+                input_image = input_image.sum(dim=0)
+                vals_seg = torch.unique(input_image)
+                seg_map = []
+
+                # generate one hot segmentation mask
+                for i in range(len(vals_seg)):
+                    mask = input_image.eq(vals_seg[i])
+                    # hack: only keep the non-background segmentation masks
+                    if mask[0][0]:
+                        continue
+                    seg_mapi = torch.zeros([input_image.shape[0],
+                                            input_image.shape[1]]).masked_fill_(mask, 1)
+                    seg_map.append(seg_mapi)
+
+                seg_map = torch.cat(seg_map).unsqueeze(0)
+                images.append(seg_map)
+            else:
+                images.append(input_image)
+
+            images[ni] = self.resize(images[ni])
+
+        # (d, resize, resize), d = 3 + objs (+ other img types *3)
+        images = torch.cat(images)
+        return images
+
+    def get_paired_batches(self, attribute, lesson):
+        base_names_sim = []
+        base_names_dif = []
+        images_sim = []
+        images_dif = []
+        if 'and' in attribute.split() or 'or' in attribute.split():
+            lesson = (lesson.split()[0], lesson.split()[2])
+        elif 'not' in attribute.split():
+            lesson = lesson.split()[1]
+
+        if ' ' not in attribute:  # if the attribute is not logical
+            while len(base_names_sim) < sim_batch:  # sim_batch = 132
+                names_dic_sim = {}
+                names_dic_dif = {}
+
+                for k, v in self.dic.items():  # iterate on 'attribute_type': [list of attributes]
+                    if ' ' not in k:  # if the attribute is not logical
+                        if k == attribute:  # if the attribute is the one we want to teach e.g. color
+                            names_dic_sim[k] = lesson  # we take the lesson e.g. red
+                            tp = random.choice(v)
+                            while (tp == lesson):
+                                tp = random.choice(v)
+                            names_dic_dif[k] = tp
+                        else:
+                            tpo = random.choice(v)  # we take a random attribute from the list of attributes e.g.
blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = f'{names_dic_sim["color"]}_{names_dic_sim["material"]}_{names_dic_sim["shape"]}_shade_{names_dic_sim["shade"]}_stretch_{names_dic_sim["stretch"]}_scale_{names_dic_sim["scale"]}_brightness_{names_dic_sim["brightness"]}_view_{names_dic_sim["view"]}' # we create the name of the image from the dict + base_name_dif = f'{names_dic_dif["color"]}_{names_dic_dif["material"]}_{names_dic_dif["shape"]}_shade_{names_dic_dif["shade"]}_stretch_{names_dic_dif["stretch"]}_scale_{names_dic_dif["scale"]}_brightness_{names_dic_dif["brightness"]}_view_{names_dic_dif["view"]}' # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + else: # if the attribute is logical + while len(base_names_sim) < sim_batch: #133 + names_dic_sim = {} + names_dic_dif = {} + if 'and' in attribute.split(): + attribute1 = attribute.split()[0] + attribute2 = attribute.split()[2] + for negative_case in range(3): # 0,1,2 [negatives] + for k, v in self.dic.items(): # iterate on 'attribute_type':[list of attributes] + if ' ' not in k: # if the attribute is not logical + if k == attribute1: + names_dic_sim[k] = lesson[0] + if negative_case == 0: + names_dic_dif[k] = lesson[0] + else: + tp = random.choice(v) + while (tp == lesson[0]): + tp = random.choice(v) + names_dic_dif[k] = tp + elif k==attribute2: + names_dic_sim[k] = lesson[1] + if negative_case == 1: + names_dic_dif[k] = lesson[1] + else: + tp = random.choice(v) + while (tp == lesson[1]): + tp = random.choice(v) + names_dic_dif[k] = tp + else: + tpo = random.choice(v) # we take a random attribute from the list of attributes e.g. 
blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = f'{names_dic_sim["color"]}_{names_dic_sim["material"]}_{names_dic_sim["shape"]}_shade_{names_dic_sim["shade"]}_stretch_{names_dic_sim["stretch"]}_scale_{names_dic_sim["scale"]}_brightness_{names_dic_sim["brightness"]}_view_{names_dic_sim["view"]}' # we create the name of the image from the dict + base_name_dif = f'{names_dic_dif["color"]}_{names_dic_dif["material"]}_{names_dic_dif["shape"]}_shade_{names_dic_dif["shade"]}_stretch_{names_dic_dif["stretch"]}_scale_{names_dic_dif["scale"]}_brightness_{names_dic_dif["brightness"]}_view_{names_dic_dif["view"]}' # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + elif 'or' in attribute.split(): + attribute1 = attribute.split()[0] + attribute2 = attribute.split()[2] + if attribute1 == attribute2: + for negative_case in range(2): # 0,1 [negatives] + for k, v in self.dic.items(): # iterate on 'attribute_type':[list of attributes] + + if ' ' not in k: # if the attribute is not logical + if k == attribute1: + tp = random.choice(v) + while (tp == lesson[0] or tp == lesson[1]): + tp = random.choice(v) + names_dic_dif[k] = tp + + if negative_case == 0: + names_dic_sim[k] = lesson[0] + if negative_case == 1: + names_dic_sim[k] = lesson[1] + else: + tpo = random.choice(v) # we take a random attribute from the list of attributes e.g. blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = f'{names_dic_sim["color"]}_{names_dic_sim["material"]}_{names_dic_sim["shape"]}_shade_{names_dic_sim["shade"]}_stretch_{names_dic_sim["stretch"]}_scale_{names_dic_sim["scale"]}_brightness_{names_dic_sim["brightness"]}_view_{names_dic_sim["view"]}' # we create the name of the image from the dict + base_name_dif = f'{names_dic_dif["color"]}_{names_dic_dif["material"]}_{names_dic_dif["shape"]}_shade_{names_dic_dif["shade"]}_stretch_{names_dic_dif["stretch"]}_scale_{names_dic_dif["scale"]}_brightness_{names_dic_dif["brightness"]}_view_{names_dic_dif["view"]}' # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + else: + for negative_case in range(3): # 0,1,2 [negatives] + for k, v in self.dic.items(): # iterate on 'attribute_type':[list of attributes] + + if ' ' not in k: # if the attribute is not logical + if k == attribute1: + tp = random.choice(v) + while (tp == lesson[0]): + tp = random.choice(v) + names_dic_dif[k] = tp + + if negative_case == 0 or negative_case == 1: + names_dic_sim[k] = lesson[0] + else: + names_dic_sim[k] = names_dic_dif[k] + + elif k==attribute2: + tp = random.choice(v) + while (tp == lesson[1]): + tp = random.choice(v) + names_dic_dif[k] = tp + + if negative_case == 0 or negative_case == 2: + names_dic_sim[k] = lesson[1] + else: + names_dic_sim[k] = names_dic_dif[k] + else: + tpo = random.choice(v) # we take a random attribute from the list of attributes e.g. 
blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = f'{names_dic_sim["color"]}_{names_dic_sim["material"]}_{names_dic_sim["shape"]}_shade_{names_dic_sim["shade"]}_stretch_{names_dic_sim["stretch"]}_scale_{names_dic_sim["scale"]}_brightness_{names_dic_sim["brightness"]}_view_{names_dic_sim["view"]}' # we create the name of the image from the dict + base_name_dif = f'{names_dic_dif["color"]}_{names_dic_dif["material"]}_{names_dic_dif["shape"]}_shade_{names_dic_dif["shade"]}_stretch_{names_dic_dif["stretch"]}_scale_{names_dic_dif["scale"]}_brightness_{names_dic_dif["brightness"]}_view_{names_dic_dif["view"]}' # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + elif 'not' in attribute.split(): + attribute1 = attribute.split()[1] + for k, v in self.dic.items(): # iterate on 'attribute_type':[list of attributes] + if ' ' not in k: # if the attribute is not logical + if k == attribute1: # if the attribute is the one we want to teach e.g. color + names_dic_dif[k] = lesson # we take the lesson e.g. red + tp = random.choice(v) + while (tp == lesson): + tp = random.choice(v) + names_dic_sim[k] = tp + else: + tpo = random.choice(v) # we take a random attribute from the list of attributes e.g. plastic + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = f'{names_dic_sim["color"]}_{names_dic_sim["material"]}_{names_dic_sim["shape"]}_shade_{names_dic_sim["shade"]}_stretch_{names_dic_sim["stretch"]}_scale_{names_dic_sim["scale"]}_brightness_{names_dic_sim["brightness"]}_view_{names_dic_sim["view"]}' # we create the name of the image from the dict + base_name_dif = f'{names_dic_dif["color"]}_{names_dic_dif["material"]}_{names_dic_dif["shape"]}_shade_{names_dic_dif["shade"]}_stretch_{names_dic_dif["stretch"]}_scale_{names_dic_dif["scale"]}_brightness_{names_dic_dif["brightness"]}_view_{names_dic_dif["view"]}' # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + images_sim = torch.stack(images_sim) + images_dif = torch.stack(images_dif) + + return base_names_sim, images_sim, base_names_dif, images_dif diff --git a/code/main.py b/code/main.py new file mode 100644 index 0000000..cdde835 --- /dev/null +++ b/code/main.py @@ -0,0 +1,205 @@ +import os +import torch +import clip +import time +import pickle +import random +import argparse +import torch.nn as nn +import torch.optim as optim +from PIL import Image + +from torch.utils.data import DataLoader + +from config import * +from dataset import * +from models import * + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def my_train_clip_encoder(dt, memory, attr, lesson): + # get model + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + if lesson in memory.keys(): + print("______________ loading_____________________") + model.load_state_dict(memory[lesson]['model']) + optimizer = optim.Adam(model.parameters(), lr=lr) + model.train().to(device) + + loss_sim = None + loss_dif = 
None + loss = 10 + ct = 0 + centroid_sim = torch.rand(1, latent_dim).to(device) + while loss > 0.008: + ct += 1 + if ct > 5: + break + for i in range(200): + # Get Inputs: sim_batch, (sim_batch, 4, 132, 132) + names_sim, images_sim, names_dif, images_dif = dt.get_paired_batches(attr,lesson) + images_sim = images_sim.to(device) + + # run similar model + z_sim = model(clip_model, images_sim) + centroid_sim = centroid_sim.detach() + centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim))) + + # Run Difference + images_dif = images_dif.to(device) + + # run difference model + z_dif = model(clip_model, images_dif) + loss_dif = get_sim_not_loss(centroid_sim, z_dif) + + # compute loss + loss = (loss_sim)**2 + (loss_dif-1)**2 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print('[', ct, ']', loss.detach().item(), loss_sim.detach().item(), + loss_dif.detach().item()) + + ############ save model ######### + with torch.no_grad(): + memory[lesson] = {'model': model.to('cpu').state_dict(), + 'arch': ['Filter', ['para_block1']], + 'centroid': centroid_sim.to('cpu') + } + return memory + + +def my_clip_evaluation(in_path, source, memory, in_base, types, dic, vocab): + with torch.no_grad(): + # get vocab dictionary + if source == 'train': + dic = dic_test + else: + dic = dic_train + + # get dataset + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab, + clip_preprocessor=clip_preprocess) + data_loader = DataLoader(dt, batch_size=132, shuffle=True) + + top3 = 0 + top3_color = 0 + top3_material = 0 + top3_shape = 0 + tot_num = 0 + + for base_is, images in data_loader: # labels (one hot), images (clip embs) + # Prepare the inputs + images = images.to(device) + ans = [] + batch_size_i = len(base_is) + + # go through memory + for label in vocab: # select a label es 'red' + if label not in memory.keys(): + ans.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1)) + continue + + # load model + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + model.load_state_dict(memory[label]['model']) # load weights corresponding to red + model.to(device) + model.eval() # freeze + + # load centroid + centroid_i = memory[label]['centroid'].to(device) + centroid_i = centroid_i.repeat(batch_size_i, 1) + + # compute stats + z = model(clip_model, images).squeeze(0) + disi = ((z - centroid_i)**2).mean(dim=1) + ans.append(disi.detach().to('cpu')) + + # get top3 indices + ans = torch.stack(ans, dim=1) + values, indices = ans.topk(3, largest=False) + _, indices_lb = base_is.topk(3) + indices_lb, _ = torch.sort(indices_lb) + + # calculate stats + tot_num += len(indices) + for bi in range(len(indices)): + ci = 0 + mi = 0 + si = 0 + if indices_lb[bi][0] in indices[bi]: + ci = 1 + if indices_lb[bi][1] in indices[bi]: + mi = 1 + if indices_lb[bi][2] in indices[bi]: + si = 1 + + top3_color += ci + top3_material += mi + top3_shape += si + if (ci == 1) and (mi == 1) and (si == 1): + top3 += 1 + + print(tot_num, top3_color/tot_num, top3_material/tot_num, + top3_shape/tot_num, top3/tot_num) + return top3/tot_num + + +def my_clip_train(in_path, out_path, model_name, source, in_base, + types, dic, vocab, pre_trained_model=None): + # get data + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab, + clip_preprocessor=clip_preprocess) + + # load encoder models from memory + memory = {} + if pre_trained_model is not None: + print(">>>>> loading 
memory >>>>>") + in_memory = os.path.join(out_path, pre_trained_model) + infile = open(in_memory, 'rb') + memory = pickle.load(infile) + infile.close() + + best_nt = 0 + t_tot = 0 + for i in range(epochs): + for tl in types_logical_with_learning: # attr + random.shuffle(dic[tl]) + for vi in dic[tl]: # lesson + print("#################### Learning: " + str(i) + " ----- " + str(vi)) + t_start = time.time() + memory = my_train_clip_encoder(dt, memory, tl, vi) + t_end = time.time() + t_dur = t_end - t_start + t_tot += t_dur + print("Time: ", t_dur, t_tot) + + # evaluate + top_nt = my_clip_evaluation(in_path, 'novel_test/', memory, + bsn_novel_test_1, ['rgba'], dic_train, vocab) + if top_nt > best_nt: + best_nt = top_nt + print("++++++++++++++ BEST NT: " + str(best_nt)) + with open(os.path.join(out_path, model_name), 'wb') as handle: + pickle.dump(memory, handle, protocol=pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument('--in_path', '-i', + help='Data input path', required=True) + argparser.add_argument('--out_path', '-o', + help='Model memory output path', required=True) + argparser.add_argument('--model_name', '-n', default='best_mem.pickle', + help='Best model memory to be saved file name', required=False) + argparser.add_argument('--pre_train', '-p', default=None, + help='Pretrained model import name (saved in outpath)', required=False) + args = argparser.parse_args() + + my_clip_train(args.in_path, args.out_path, args.model_name, + 'novel_train/', bn_n_train, ['rgba'], dic_train_logical, all_vocabs, args.pre_train) diff --git a/code/models.py b/code/models.py new file mode 100644 index 0000000..db0218a --- /dev/null +++ b/code/models.py @@ -0,0 +1,43 @@ +'''ResNet in PyTorch. +For Pre-activation ResNet, see 'preact_resnet.py'. +Credit: https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py +VAE Credit: https://github.com/AntixK/PyTorch-VAE/tree/a6896b944c918dd7030e7d795a8c13e5c6345ec7 +Contrastive Loss: https://lilianweng.github.io/posts/2021-05-31-contrastive/ +CLIP train: https://github.com/openai/CLIP/issues/83 + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 +''' +import clip +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +from config import * +device = "cuda" if torch.cuda.is_available() else "cpu" + + +class CLIP_AE_Encode(nn.Module): + def __init__(self, hidden_dim, latent_dim, isAE=False): + super(CLIP_AE_Encode, self).__init__() + # Build Encoder + self.fc1 = nn.Linear(512, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, latent_dim) + self.relu = nn.ReLU(inplace=True) + + if isAE: + self.filter = nn.Parameter(torch.ones((512))) + else: + self.filter = nn.Parameter(torch.rand((512))) + + def forward(self, clip_model, images): + with torch.no_grad(): + emb = clip_model.encode_image(images).float() + out = emb * self.filter + out = self.relu(self.fc1(out)) + z = self.fc2(out) + + return z diff --git a/code/open_mem.py b/code/open_mem.py new file mode 100644 index 0000000..7686505 --- /dev/null +++ b/code/open_mem.py @@ -0,0 +1,121 @@ +#%% +import pickle +import os +import torch +import clip +import time +import pickle +import random +import argparse +import torch.nn as nn +import torch.optim as optim +from PIL import Image + +from torch.utils.data import DataLoader + +from config import * +from dataset import * +from models import * + +device = "cuda" if torch.cuda.is_available() else "cpu" + +path = 'best_mem.pickle' + +with open(path, 'rb') as f: + memory = pickle.load(f) +#%% +in_path = '/Users/filippomerlo/Desktop/Datasets/SOLA' +source = 'train' +memory = memory +in_base = bn_train +types = tyimgs +dic = dic_train_logical +vocab = all_vocabs +# %% +def my_clip_evaluation(in_path, source, memory, in_base, types, dic, vocab): + + with torch.no_grad(): + # get vocab dictionary + if source == 'train': + dic = dic_test + else: + dic = dic_train + + # get dataset + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab, + clip_preprocessor=clip_preprocess) + data_loader = DataLoader(dt, batch_size=132, shuffle=True) + + top3 = 0 + top3_color = 0 + top3_material = 0 + top3_shape = 0 + tot_num = 0 + i = 0 + #for base_is, images in data_loader: # labels (one hot), images (clip embs) + base_is, images = next(iter(data_loader)) + # Prepare the inputs + images = images.to(device) + ans = [] + rel = [] + batch_size_i = len(base_is) + + # go through memory + for label in vocab: # select a label es 'red' + if label not in memory.keys(): + ans.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1)) + continue + + # load model + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + model.load_state_dict(memory[label]['model']) # load weights corresponding to red + model.to(device) + model.eval() # freeze + + # load centroid + centroid_i = memory[label]['centroid'].to(device) + centroid_i = centroid_i.repeat(batch_size_i, 1) + + # compute stats + z = model(clip_model, images).squeeze(0) + disi = ((z - centroid_i)**2).mean(dim=1) + ans.append(disi.detach().to('cpu')) + + # get top3 indices + ans = torch.stack(ans, dim=1) + values, indices = ans.topk(3, largest=False) + _, indices_lb = base_is.topk(3) + # base_is = [00001000000010000001] + # indices_lb [5,12,19] + indices_lb, _ = torch.sort(indices_lb) + + # calculate stats + tot_num += len(indices) + for bi in range(len(indices)): + ci = 0 + mi = 0 + si = 0 + print('***',indices[bi],'***') + if indices_lb[bi][0] in indices[bi]: + print(indices_lb[bi][0]) + ci = 1 + if indices_lb[bi][1] in indices[bi]: + print(indices_lb[bi][1]) + mi = 1 + if indices_lb[bi][2] 
in indices[bi]: + print(indices_lb[bi][2]) + si = 1 + + top3_color += ci + top3_material += mi + top3_shape += si + if (ci == 1) and (mi == 1) and (si == 1): + top3 += 1 + + print(tot_num, top3_color/tot_num, top3_material/tot_num, + top3_shape/tot_num, top3/tot_num) + return top3/tot_num + +n = my_clip_evaluation(in_path, source, memory, in_base, types, dic, vocab) +print(n) \ No newline at end of file diff --git a/code/try.py b/code/try.py new file mode 100644 index 0000000..1f99422 --- /dev/null +++ b/code/try.py @@ -0,0 +1,40 @@ +#%% +from dataset import * +from config import * + +from torch.utils.data import DataLoader +from pprint import pprint + +import matplotlib.pyplot as plt + +in_path = '/Users/filippomerlo/Desktop/Datasets/SOLA' +source = 'train' +in_base = bn_train +types = tyimgs +dic = dic_train_logical +vocab = all_vocabs + +dt = MyDataset(in_path, source, in_base, types, dic, vocab) +data_loader = DataLoader(dt, batch_size=132, shuffle=True) + +#%% +train_labels, train_features = next(iter(data_loader)) +print(f"Feature batch shape: {train_features.size()}") +print(f"Labels batch shape: {train_labels.size()}") +pprint(train_features) + +#%% Class Methods +#dt.__len__() +a,b = dt.__getitem__(1240) +print(a) +#%% NEW Functions +# - With attributes pairing between positive and negative samples + +attr = 'color' +lesson = 'red' +names_sim, images_sim, names_dif, images_dif = dt.get_paired_batches(attr,lesson) + +for i,n in enumerate(names_sim): + print('**********',i,'**********') + print(n,'\n',names_dif[i]) + diff --git a/code/util.py b/code/util.py new file mode 100644 index 0000000..07952bf --- /dev/null +++ b/code/util.py @@ -0,0 +1,83 @@ +import torch + +from config import * +from dataset import * + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +''' Parse file names +-- Training + 'red_rubber_cylinder_shade_base_stretch_normal_scale_large_brightness_normal_view_-2_-2_2' +-- Testing + 'yellow_rubber_sponge_shade_base_stretch_normal_scale_small_brightness_dim_view_1.5_1.5_3' + + ** 0_1_2 + red_rubber_cylinder_ + ** 3_4 + shade_base_ + ** 5_6 + stretch_normal_ + ** 7_8 + scale_large_ + ** 9_10 + brightness_normal_ + ** 11_12_13_14_ + view_-2_-2_2_ + ** x.png + rgba.png + + diffn_list = [(i,j) for i, j in zip(n1, n2) if i != j] + x[3:6] = [''.join(x[3:6])] +''' + + +def pareFileNames(base_name): + # get different attr names + n1 = base_name.split('_') + # regroup shape names with '_', e.g. torus_knot + if len(n1) > 15: + n1[2:4] = ['_'.join(n1[2:4])] + # regroup view points + n1[12:15] = ['_'.join(n1[12:15])] + + nm1 = {} + nm1["color"] = n1[0] + nm1["material"] = n1[1] + nm1["shape"] = n1[2] + nm1["shade"] = n1[4] + nm1["stretch"] = n1[6] + nm1["scale"] = n1[8] + nm1["brightness"] = n1[10] + nm1["view"] = n1[12] + + return nm1 + + +def get_mse_loss(recons, images): + recons_loss = F.mse_loss(recons, images) + return recons_loss + + +def get_mse_loss_more(recons, images): + recons_loss = 0 + for i in range(images.shape[0]): + recons_loss += F.mse_loss(recons[0], images[i]) + return recons_loss/images.shape[0] + + +def get_sim_loss(z): + centroid = torch.mean(z, dim=0) + loss = 0 + for i in range(z.shape[0]): + loss += F.mse_loss(centroid, z[i]) + + return centroid, loss/z.shape[0] + + +def get_sim_not_loss(centroid, z): + loss = 0 + for i in range(z.shape[0]): + loss += F.mse_loss(centroid, z[i]) + + return loss/z.shape[0]
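
Usage sketch: a minimal example of how the pieces above fit together, assuming the repo's code/ directory is the working directory and that the dataset path passed to main.py is a placeholder; pareFileNames and the command-line flags are taken from util.py and main.py above.

# Minimal usage sketch (assumptions: run from the code/ directory; placeholder paths).
from util import pareFileNames

# Decompose a base name that follows the naming convention documented in config.py.
nm = pareFileNames('red_rubber_cylinder_shade_base_stretch_normal_'
                   'scale_large_brightness_normal_view_-2_-2_2')
print(nm['color'], nm['material'], nm['shape'], nm['view'])
# -> red rubber cylinder -2_-2_2

# Training entry point (flags defined in main.py's argparse):
#   python main.py -i /path/to/SOLA -o /path/to/output -n best_mem.pickle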