diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..9370c12 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..f6f8da6 Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ diff --git a/__pycache__/dataset.cpython-310.pyc b/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000..809c365 Binary files /dev/null and b/__pycache__/dataset.cpython-310.pyc differ diff --git a/__pycache__/util.cpython-310.pyc b/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..b946e22 Binary files /dev/null and b/__pycache__/util.cpython-310.pyc differ diff --git a/code/.DS_Store b/code/.DS_Store new file mode 100644 index 0000000..13a37c6 Binary files /dev/null and b/code/.DS_Store differ diff --git a/code/__pycache__/config.cpython-310.pyc b/code/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..4c490fb Binary files /dev/null and b/code/__pycache__/config.cpython-310.pyc differ diff --git a/code/__pycache__/dataset.cpython-310.pyc b/code/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000..bede805 Binary files /dev/null and b/code/__pycache__/dataset.cpython-310.pyc differ diff --git a/code/__pycache__/models.cpython-310.pyc b/code/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000..0aa3849 Binary files /dev/null and b/code/__pycache__/models.cpython-310.pyc differ diff --git a/code/__pycache__/util.cpython-310.pyc b/code/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..a533503 Binary files /dev/null and b/code/__pycache__/util.cpython-310.pyc differ diff --git a/code/config.py b/code/config.py new file mode 100644 index 0000000..35d030c --- /dev/null +++ b/code/config.py @@ -0,0 +1,169 @@ +#%% +''' +Learning Attributes: + - Color (6) + - Material (3) + - Shape (8) +Additional Attributes: + - Color (2) + - Material (1) + - Shape (3) +Flexibility: + - Camera angle (6) + - Lighting (3) +Variability (Only in testing): + - Size (2) [Default: large] + - Stretch (3) [Default: normal] + - Color shade (2) [Default: base] + +Naming convension: +[color]_[material]_[shape]_shade_[]_stretch_[]_scale_[]_brightness_view_[]_[tyimg].png +e.g. 
+aqua_glass_cone_shade_base_stretch_normal_scale_large_brightness_bright_view_0_-2_3_rgba.png +''' +# Learning attributes: +colors = ['brown', "green", "blue", "aqua", "purple", "red", "yellow", 'white'] +materials = ['rubber', 'metal', 'plastic', 'glass'] +shapes = ["cube", "cylinder", "sphere", "cone", "torus", "gear", + "torus_knot", "sponge", "spot", "teapot", "suzanne"] +vocabs = colors+materials+shapes + + +# Flexibility: +views = ['0_3_2', '-2_-2_2', '-2_2_2', '1.5_-1.5_3', '1.5_1.5_3', '0_-2_3'] +brightness = ['dim', 'normal', 'bright'] + +# Variability +scale_train = ['large'] +stretch_train = ['normal'] +shade_train = ['base'] + +scale_test = ['small', 'medium', 'large'] +stretch_test = ['normal', 'x', 'y', 'z'] +shade_test = ['base', 'light', 'dark'] + +others = views + brightness + scale_test + stretch_test + shade_test + +# Types of images +tyimgs = ['rgba', 'depth', 'normal', 'object_coordinates', 'segmentation'] + + +dic_train = {"color": colors, + "material": materials, + "shape": shapes, + "view": views, + 'brightness': brightness, + "scale": scale_train, + 'stretch': stretch_train, + 'shade': shade_train + } +dic_test = {"color": colors, + "material": materials, + "shape": shapes, + "view": views, + 'brightness': brightness, + "scale": scale_test, + 'stretch': stretch_test, + 'shade': shade_test + } + +types_learning = ['color', 'material', 'shape'] +types_flebility = ['color', 'material', 'shape', 'brightness', 'view'] +types_variability = ['scale', 'stretch', 'shade'] +types_all = ['color', 'material', 'shape', 'brightness', + 'view', 'shade', 'stretch', 'scale'] + +# make dicts for logical traing and testing +relations = ['and', 'or', 'not'] # <--- new +types_logical = [] +for i in types_learning: + for j in relations: + if j == 'not': + types_logical.append(j+' '+i) + else: + for h in types_learning: + if h+' '+j+' '+i not in types_logical: + if j == 'and' and i == h: + pass + else: + types_logical.append(i+' '+j+' '+h) +types_logical_with_learning = types_logical + types_learning + +from itertools import product +from pprint import pprint + +dic_train_logical = dic_train.copy() +for rel in types_logical: + if rel.split(' ')[0] == 'not': + attr = rel.split(' ')[1] + dic_train_logical[rel] = [f'not {x}' for x in dic_train[attr]] + else: + attr1 = rel.split(' ')[0] + r = rel.split(' ')[1] + attr2 = rel.split(' ')[2] + dic_train_logical[rel] = [f'{x} {r} {y}' for x, y in product(dic_train[attr1], dic_train[attr2]) if x != y] + +dic_test_logical = dic_train_logical.copy() +dic_test_logical["scale"] = dic_test["scale"] +dic_test_logical["stretch"] = dic_test["stretch"] +dic_test_logical["shade"] = dic_test["shade"] + + + +all_vocabs = [] +for v in dic_train_logical.values(): + for n in v: + if n not in others: + all_vocabs.append(n) + +# count n of concepts + +types_logical_with_learning_1 = types_logical_with_learning[0:2] +print(types_logical_with_learning_1) +types_logical_with_learning_2 = types_logical_with_learning[2:4] +print(types_logical_with_learning_2) +types_logical_with_learning_3 = types_logical_with_learning[4:6] +print(types_logical_with_learning_3) +types_logical_with_learning_4 = types_logical_with_learning[6:8] +print(types_logical_with_learning_4) +types_logical_with_learning_5 = types_logical_with_learning[8:10] +print(types_logical_with_learning_5) +types_logical_with_learning_6 = types_logical_with_learning[10:12] +print(types_logical_with_learning_6) +types_logical_with_learning_7 = types_logical_with_learning[12:] 
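# --- Editor's note (illustrative only, not part of the committed file) ---
# The loop above enumerates every 'and'/'or'/'not' combination of the learning
# attribute types, and dic_train_logical expands each combination into concrete
# lessons via itertools.product. Given the attribute lists at the top of this
# file, the seven splits built above are expected to contain:
#
#   types_logical_with_learning_1 -> ['color and material', 'color and shape']
#   types_logical_with_learning_2 -> ['color or color', 'color or material']
#   types_logical_with_learning_3 -> ['color or shape', 'not color']
#   types_logical_with_learning_4 -> ['material and shape', 'material or material']
#   types_logical_with_learning_5 -> ['material or shape', 'not material']
#   types_logical_with_learning_6 -> ['shape or shape', 'not shape']
#   types_logical_with_learning_7 -> ['color', 'material', 'shape']   # the plain attributes
#
# and the generated lesson lists look like, e.g.:
#   dic_train_logical['not color'][:2]         -> ['not brown', 'not green']
#   dic_train_logical['color and material'][0] -> 'brown and rubber'
#   len(dic_train_logical['color or color'])   -> 56   # ordered pairs of distinct colors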
+print(types_logical_with_learning_7) + +#print(all_vocabs) +#pprint(dic_train_logical) +#print(types_logical_with_learning) + +# paths and filenames +bn_n_train = "bn_n_train.txt" +bsn_novel_train_1 = "bsn_novel_train_1.txt" +bsn_novel_train_2 = "bsn_novel_train_2.txt" +bsn_novel_train_2_nw = "bsn_novel_train_2_nw.txt" +bsn_novel_train_2_old = "bsn_novel_train_2_old.txt" + +bn_n_test = "bn_n_test.txt" +bsn_novel_test_1 = "bsn_novel_test_1.txt" +bsn_novel_test_2_nw = "bsn_novel_test_2_nw.txt" +bsn_novel_test_2_old = "bsn_novel_test_2_old.txt" + +bn_train = "bn_train.txt" +bn_test = "bn_test.txt" +bsn_test_1 = "bsn_test_1.txt" +bsn_test_2_nw = "bsn_test_2_nw.txt" +bsn_test_2_old = "bsn_test_2_old.txt" + +# train parameters +resize = 224 +lr = 1e-3 +epochs = 1 + +sim_batch = 132 +gen_batch = 132 +batch_size = 33 + +# model architecture +hidden_dim_clip = 128 +latent_dim = 16 diff --git a/code/dataset.py b/code/dataset.py new file mode 100644 index 0000000..4f39fd8 --- /dev/null +++ b/code/dataset.py @@ -0,0 +1,274 @@ +import os +import torch +import random +from PIL import Image +import torch.nn.functional as F +import torchvision.transforms as TT +from torch.utils.data.dataset import Dataset + +from config import * +from util import * + + +class MyDataset(): + def __init__(self, in_path, source, in_base, types, + dic, vocab, clip_preprocessor=None): + self.dic = dic + self.source = source + self.types = types + self.in_path = in_path + self.totensor = TT.ToTensor() + self.resize = TT.Resize((resize, resize)) + self.clip_preprocessor = clip_preprocessor + self.dic_without_logical = {k:v for k,v in self.dic.items() if ' ' not in k} + + # convert vocab list to dic + self.vocab = vocab + self.vocab_nums = {xi: idx for idx, xi in enumerate(self.vocab)} + + # Get list of test images + self.names_list = [] + with open(os.path.join(self.in_path, 'names', in_base)) as f: + lines = f.readlines() + for line in lines: + self.names_list.append(line[:-1]) + + self.name_set = set(self.names_list) + + def __len__(self): + return len(self.names_list) + + # only for CLIP emb + def __getitem__(self, idx): + base_name = self.names_list[idx] + image = self.img_emb(base_name) + + # get label indicies + nm = pareFileNames(base_name) + num_labels = [self.vocab_nums[li] for li in [nm['color'], + nm['material'], nm['shape']]] + + # turn num_labels into one-hot + labels = torch.zeros(len(self.vocab)) + for xi in num_labels: + labels[xi] = 1 + + return labels, image + + def img_emb(self, base_name): + # get names + names = [] + for tp in self.types: + names.append(os.path.join(self.in_path, self.source, + base_name + '_' + tp + '.png')) + + # if clip preprocess + if self.clip_preprocessor is not None: + images = self.clip_preprocessor(Image.open(names[0])) + return images + + # preprocess images + images = [] + for ni in range(len(names)): + input_image = Image.open(names[ni]).convert('RGB') + input_image = self.totensor(input_image) + + if names[ni][-16:] == "segmentation.png": + input_image = input_image.sum(dim=0) + vals_seg = torch.unique(input_image) + seg_map = [] + + # generate one hot segmentation mask + for i in range(len(vals_seg)): + mask = input_image.eq(vals_seg[i]) + # hack: only keep the non-background segmentation masks + if mask[0][0] is True: + continue + seg_mapi = torch.zeros([input_image.shape[0], + input_image.shape[1]]).masked_fill_(mask, 1) + seg_map.append(seg_mapi) + + seg_map = torch.cat(seg_map).unsqueeze(0) + images.append(seg_map) + else: + images.append(input_image) + + images[ni] = 
self.resize(images[ni]) + + # (d, resize, resize), d = 3 + #objs (+ other img types *3) + images = torch.cat(images) + return images + + def get_paired_batches(self, attribute, lesson): + base_names_sim = [] + base_names_dif = [] + images_sim = [] + images_dif = [] + + def get_random_attribute(attribute_list, exclude=None): + attr = random.choice(attribute_list) + while attr == exclude: + attr = random.choice(attribute_list) + return attr + + def create_base_name(names_dic): + return f'{names_dic["color"]}_{names_dic["material"]}_{names_dic["shape"]}_shade_{names_dic["shade"]}_stretch_{names_dic["stretch"]}_scale_{names_dic["scale"]}_brightness_{names_dic["brightness"]}_view_{names_dic["view"]}' + + if 'and' in attribute.split() or 'or' in attribute.split(): + lesson = (lesson.split()[0], lesson.split()[2]) + ats = attribute.split() + attribute1 = ats[0] + attribute2 = ats[2] + elif 'not' in attribute.split(): + lesson = lesson.split()[1] + attribute1 = attribute.split()[1] + attribute2 = None + + if ' ' not in attribute: # if the attribute is not logical + while len(base_names_sim) < sim_batch: #133 + names_dic_sim = {} + names_dic_dif = {} + + for k, v in self.dic_without_logical.items(): # iterate on 'attribute_type':[list of attributes] + if k == attribute: # if the attribute is the one we want to teach e.g. color + names_dic_sim[k] = lesson # we take the lesson e.g. red + names_dic_dif[k] = get_random_attribute(v,lesson) + else: + tpo = get_random_attribute(v) # we take a random attribute from the list of attributes e.g. blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = create_base_name(names_dic_sim) # we create the name of the image from the dict + base_name_dif = create_base_name(names_dic_dif) # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + else: # if the attribute is logical + while len(base_names_sim) < sim_batch: #133 + names_dic_sim = {} + names_dic_dif = {} + if 'and' in attribute.split(): + for negative_case in range(3): # 0,1,2 [negatives] + for k, v in self.dic_without_logical.items(): # iterate on 'attribute_type':[list of attributes] + if k == attribute1: + names_dic_sim[k] = lesson[0] + if negative_case == 0: + names_dic_dif[k] = lesson[0] + else: + names_dic_dif[k] = get_random_attribute(v,lesson[0]) + elif k==attribute2: + names_dic_sim[k] = lesson[1] + if negative_case == 1: + names_dic_dif[k] = lesson[1] + else: + names_dic_dif[k] = get_random_attribute(v,lesson[1]) + else: + tpo = get_random_attribute(v) # we take a random attribute from the list of attributes e.g. 
blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = create_base_name(names_dic_sim) # we create the name of the image from the dict + base_name_dif = create_base_name(names_dic_dif) # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + elif 'or' in attribute.split(): + if attribute1 == attribute2: + for negative_case in range(2): # 0,1 [negatives] + for k, v in self.dic_without_logical.items(): # iterate on 'attribute_type':[list of attributes] + if k == attribute1: + tp = get_random_attribute(v) + while (tp == lesson[0] or tp == lesson[1]): + tp = get_random_attribute(v) + names_dic_dif[k] = tp + + if negative_case == 0: + names_dic_sim[k] = lesson[0] + elif negative_case == 1: + names_dic_sim[k] = lesson[1] + else: + tpo = get_random_attribute(v) # we take a random attribute from the list of attributes e.g. blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = create_base_name(names_dic_sim) # we create the name of the image from the dict + base_name_dif = create_base_name(names_dic_dif) # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + else: + for negative_case in range(3): # 0,1,2 [negatives] + for k, v in self.dic_without_logical.items(): # iterate on 'attribute_type':[list of attributes] + if k == attribute1: + names_dic_dif[k] = get_random_attribute(v, lesson[0]) + + if negative_case == 0 or negative_case == 1: + names_dic_sim[k] = lesson[0] + else: + names_dic_sim[k] = names_dic_dif[k] + + elif k==attribute2: + names_dic_dif[k] = get_random_attribute(v,lesson[1]) + + if negative_case == 0 or negative_case == 2: + names_dic_sim[k] = lesson[1] + else: + names_dic_sim[k] = names_dic_dif[k] + else: + tpo = get_random_attribute(v) # we take a random attribute from the list of attributes e.g. blue + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = create_base_name(names_dic_sim) # we create the name of the image from the dict + base_name_dif = create_base_name(names_dic_dif) # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + elif 'not' in attribute.split(): + for k, v in self.dic_without_logical.items(): # iterate on 'attribute_type':[list of attributes] + if k == attribute1: # if the attribute is the one we want to teach e.g. color + names_dic_dif[k] = lesson # we take the lesson e.g. red + names_dic_sim[k] = get_random_attribute(v,lesson) + else: + tpo = get_random_attribute(v) # we take a random attribute from the list of attributes e.g. 
plastic + names_dic_sim[k] = tpo + names_dic_dif[k] = tpo + base_name_sim = create_base_name(names_dic_sim) # we create the name of the image from the dict + base_name_dif = create_base_name(names_dic_dif) # we create the name of the image from the dict + + if base_name_sim in self.name_set and base_name_dif in self.name_set: + base_names_sim.append(base_name_sim) + image = self.img_emb(base_name_sim) + images_sim.append(image) + + base_names_dif.append(base_name_dif) + image = self.img_emb(base_name_dif) + images_dif.append(image) + + images_sim = torch.stack(images_sim) + images_dif = torch.stack(images_dif) + + return base_names_sim, images_sim, base_names_dif, images_dif diff --git a/code/eval.py b/code/eval.py new file mode 100644 index 0000000..62cbc0e --- /dev/null +++ b/code/eval.py @@ -0,0 +1,214 @@ +#%% +import torch +import clip + +from torch.utils.data import DataLoader + +from config import * +from dataset import * +from models import * + +import pickle +import argparse + +#print(torch.backends.mps.is_available()) # the MacOS is higher than 12.3+ +#print(torch.backends.mps.is_built()) # MPS is activated +#device = torch.device('mps') +device = "cuda" if torch.cuda.is_available() else "cpu" + +def my_clip_evaluation(in_path, source, memory, in_base, types, dic, vocab): + with torch.no_grad(): + # get vocab dictionary + if source == 'train': + dic = dic_test + else: + dic = dic_test + + # get dataset + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab, clip_preprocessor=clip_preprocess) + data_loader = DataLoader(dt, batch_size=10, shuffle=True) + + top3 = 0 + top3_color = 0 + top3_material = 0 + top3_shape = 0 + + top3_and = 0 + top3_color_and_material = 0 + top3_color_and_shape = 0 + top3_material_and_shape = 0 + tot_num_and = 0 + + top3_or = 0 + top3_color_or_material = 0 + top3_color_or_shape = 0 + top3_material_or_shape = 0 + top3_color_or_color = 0 + top3_material_or_material = 0 + top3_shape_or_shape = 0 + tot_num_or = 0 + + top3_not = 0 + top3_not_color = 0 + top3_not_material = 0 + top3_not_shape = 0 + tot_num_not = 0 + + tot_num = 0 + + for base_is, images in data_loader: + # Prepare the inputs + images = images.to(device) + ans = [] + batch_size_i = len(base_is) + + # go through memory + for label in vocabs: + if label not in memory.keys(): + ans.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1)) + continue + + # load model + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + model.load_state_dict(memory[label]['model']) + model.to(device) + model.eval() + + # load centroid + centroid_i = memory[label]['centroid'].to(device) + centroid_i = centroid_i.repeat(batch_size_i, 1) + + # compute stats + z = model(clip_model, images).squeeze(0) + disi = ((z - centroid_i) ** 2).mean(dim=1) + ans.append(disi.detach().to('cpu')) + + # get top3 incicies + ans = torch.stack(ans, dim=1) + values, indices = ans.topk(3, largest=False) + _, indices_lb = base_is.topk(3) + indices_lb, _ = torch.sort(indices_lb) + + # calculate stats + tot_num += len(indices) + for bi in range(len(indices)): + ci = 0 + mi = 0 + si = 0 + if indices_lb[bi][0] in indices[bi]: + ci = 1 + if indices_lb[bi][1] in indices[bi]: + mi = 1 + if indices_lb[bi][2] in indices[bi]: + si = 1 + + top3_color += ci + top3_material += mi + top3_shape += si + + if (ci == 1) and (mi == 1) and (si == 1): + top3 += 1 + + print(tot_num, top3_color / tot_num, top3_material / tot_num, top3_shape / tot_num, top3 / tot_num) + + rel_list 
= [] + ans_logical = [] + for i, label in enumerate(all_vocabs): + print(f'{i}/{len(all_vocabs)}') + if ' ' in label: + if label not in memory.keys(): + ans_logical.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1)) + continue + s = label.split(' ') + if 'not' in s: + rel = s[0] + attr1 = s[1] + attr2 = None + rel_list.append([rel, vocab.index(attr1)]) + else: + rel = s[1] + attr1 = s[0] + attr2 = s[2] + rel_list.append([rel, vocab.index(attr1), vocab.index(attr2)]) + # load model + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + model.load_state_dict(memory[label]['model']) + model.to(device) + model.eval() + + # load centroid + centroid_i = memory[label]['centroid'].to(device) + centroid_i = centroid_i.repeat(batch_size_i, 1) + + # compute stats + z = model(clip_model, images).squeeze(0) + disi = ((z - centroid_i) ** 2).mean(dim=1) + ans_logical.append(disi.detach().to('cpu')) + # get top3 incicies + ans_logical = torch.stack(ans_logical, dim=1) + values, indices = ans_logical.topk(3, largest=False) + + _, indices_lb = base_is.topk(3) + indices_lb, _ = torch.sort(indices_lb) + # calculate stats + tot_num += len(indices) + for bi in range(len(indices)): + rel = rel_list[bi][0] + + if rel == 'not': + attr = rel[1] + if attr not in indices[bi]: + top3_not += 1 + tot_num_not += 1 + + elif rel == 'and': + attr1 = rel[1] + attr2 = rel[2] + if attr1 in indices[bi] and attr2 in indices[bi]: + top3_and += 1 + tot_num_and += 1 + elif rel == 'or': + attr1 = rel[1] + attr2 = rel[2] + if attr1 in indices[bi] or attr2 in indices[bi]: + top3_or += 1 + tot_num_or += 1 + + tot_logical = tot_num_not + tot_num_and + tot_num_or + print('Logical, tot, not, and , or') + print(tot_logical / tot_num, top3_not / tot_num_not, top3_and / tot_num_and, top3_or / tot_num_or) + + return top3 / tot_num + + +#TESTING + + +source = 'novel_test/' +in_base = bsn_novel_test_1 +types = ['rgba'] +dic = dic_test_logical +vocab = all_vocabs + +#in_path = '/Users/filippomerlo/Desktop/Datasets/SOLA' +#memory_path = '/Users/filippomerlo/Desktop/memories/my_best_mem_1.pickle' + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument('--in_path', type=str, required=True) + argparser.add_argument('--memory_path', type=str, required=True) + args = argparser.parse_args() + + with open(args.memory_path, 'rb') as f: + memory_complete = pickle.load(f) + for i in range(2, 7): + pieces = args.memory_path.split('my_best_mem_') + new_path = pieces[0] + f'my_best_mem_{i}.pickle' + with open(new_path, 'rb') as f: + memory = pickle.load(f) + for k in memory.keys(): + memory_complete[k] = memory[k] + + t = my_clip_evaluation(in_path, source, memory_complete, in_base, types, dic, vocab) + diff --git a/code/learn_decode.py b/code/learn_decode.py new file mode 100644 index 0000000..cb4bcc7 --- /dev/null +++ b/code/learn_decode.py @@ -0,0 +1,379 @@ +from models import Decoder +from config import * +from dataset import MyDataset +from util import * + +import argparse +import pickle +import torch +from torch import nn +import os +import clip +import random as rn +rn.seed(42) +import wandb + +clip_model, preprocess = clip.load("ViT-B/32", device=device) + +dec_types_logical_with_learning = [ + 'color and material', + 'color and shape', + 'color or material', + 'color or shape', + 'not color', + 'material and shape', + 'material or shape', + 'not material', + 'not shape'] + +dec_dic_train_logical = dic_train_logical.copy() +for k in dec_types_logical_with_learning: + if 
len(dic_train_logical[k]) > 5: + dec_dic_train_logical[k] = rn.sample(dic_train_logical[k], 5) + +dec_types_logical_with_learning += types_learning + +def train_decoder(in_path, out_path, model_name, memory_path, + clip_model, clip_preprocessor, + source, in_base, types, dic, vocab, + epochs=3, lr=0.001, batch_size=108, latent_dim=16): + + batch_size = batch_size + with open(memory_path, 'rb') as f: + memory = pickle.load(f) + + dt = MyDataset(in_path, source, in_base, types, dic, vocab, + clip_preprocessor) + + clip_model.eval() + + for attr in dec_types_logical_with_learning: + print('learning' + ' ' + attr) + for lesson in dic[attr]: + wandb.init(project='decode '+attr+' logical', name=lesson) # Replace with your project name and run name + print(lesson) + # build decoder + dec = Decoder(latent_dim).to(device) + + # optimizer + optimizer = torch.optim.Adam(dec.parameters(), lr=lr) + # train decoder + for epoch in range(epochs): + + for round in range(0,100): + operators = ['and', 'or', 'not'] + if not any(op in lesson.split() for op in operators): + rec_name_sim, rec_img_sim, _ , _ = dt.get_paired_batches(attr, lesson, batch_size) + edit_name_sim, edit_img_sim, edit_name_diff , edit_img_diff = dt.get_paired_batches(attr, lesson, batch_size) + loss_list = list() + + for i in range(0,batch_size): + + with torch.no_grad(): + # load sample + rec_name = rec_name_sim[i] + rec_img = clip_model.encode_image(rec_img_sim[i].unsqueeze(0).to(device)) + + edit1_name = edit_name_sim[i] + edit1_img = clip_model.encode_image(edit_img_sim[i].unsqueeze(0).to(device)) + + edit2_name = edit_name_diff[i] + edit2_img = clip_model.encode_image(edit_img_diff[i].unsqueeze(0).to(device)) + + # get second attribute for edit: lesson_filter + + lesson_filter = edit2_name.split('_')[0:3] + if 'color' in attr: + lesson_filter = lesson_filter[0] + elif 'material' in attr: + lesson_filter = lesson_filter[1] + elif 'shape' in attr: + lesson_filter = lesson_filter[2] + + # load centroid + centroid_i = memory[lesson]['centroid'].float().to(device) + + # load filters + filter_i_edit = memory[lesson_filter]['model']['filter'].to(device) + + filter_i_rec = memory[lesson]['model']['filter'].to(device) + + # forward + output = dec.forward(centroid_i) + + # edit + q2p = edit2_img * (1 - filter_i_edit) + output + + # reconstruct + p2p = rec_img * (1 - filter_i_rec) + output + + q2p_loss = 1 - get_cos_sim(edit1_img.float(), q2p.float()) + q2p_loss = q2p_loss.to(device) + + p2p_loss = 1 - get_cos_sim(rec_img.float(), p2p.float()) + p2p_loss = p2p_loss.to(device) + + loss = q2p_loss + p2p_loss + loss = loss.to(device) + loss_list.append(loss) + + elif 'and' in lesson.split(): + bs = batch_size*3 + rec_name_sim, rec_img_sim, _ , _ = dt.get_paired_batches(attr, lesson, bs) + edit_name_sim, edit_img_sim, edit_name_diff , edit_img_diff = dt.get_paired_batches(attr, lesson, bs) + loss_list = list() + + for i in range(0,bs,3): + + with torch.no_grad(): + # load sample + rec_name = rec_name_sim[i] + rec_img = clip_model.encode_image(rec_img_sim[i].unsqueeze(0).to(device)) + + edit1_name = edit_name_sim[i+2] + edit1_img = clip_model.encode_image(edit_img_sim[i].unsqueeze(0).to(device)) + + edit2_name = edit_name_diff[i+2] + edit2_img = clip_model.encode_image(edit_img_diff[i+2].unsqueeze(0).to(device)) + + # get second attribute for edit: lesson_filter + + attrs = [attr.split(' ')[0], attr.split(' ')[2]] + attrs_idx = [] + for a in attrs: + attrs_idx.append(types_learning.index(a)) + instances = edit2_name.split('_')[0:3] + filt_name = 
instances[attrs_idx[0]]+' and '+instances[attrs_idx[1]] + + # load centroid + centroid_i = memory[lesson]['centroid'].float().to(device) + + # load filters + filter_i_edit = memory[filt_name]['model']['filter'].to(device) + + filter_i_rec = memory[lesson]['model']['filter'].to(device) + + # forward + output = dec.forward(centroid_i) + + # edit + q2p = edit2_img * (1 - filter_i_edit) + output + + # reconstruct + p2p = rec_img * (1 - filter_i_rec) + output + + q2p_loss = 1 - get_cos_sim(edit1_img.float(), q2p.float()) + q2p_loss = q2p_loss.to(device) + + p2p_loss = 1 - get_cos_sim(rec_img.float(), p2p.float()) + p2p_loss = p2p_loss.to(device) + + loss = q2p_loss + p2p_loss + loss = loss.to(device) + loss_list.append(loss) + + elif 'or' in lesson.split(): + bs = batch_size*3 + rec_name_sim, rec_img_sim, _ , _ = dt.get_paired_batches(attr, lesson, bs) + edit_name_sim, edit_img_sim, edit_name_diff , edit_img_diff = dt.get_paired_batches(attr, lesson, bs) + loss_list = list() + + for i in range(0,bs,3): + + with torch.no_grad(): + # load sample + # reconstruct + rec_name_1 = rec_name_sim[i] + rec_img_1 = clip_model.encode_image(rec_img_sim[i].unsqueeze(0).to(device)) + + rec_name_2 = rec_name_sim[i+1] + rec_img_2 = clip_model.encode_image(rec_img_sim[i+1].unsqueeze(0).to(device)) + + rec_name_3 = rec_name_sim[i+2] + rec_img_3 = clip_model.encode_image(rec_img_sim[i+2].unsqueeze(0).to(device)) + + # edit + # es + # red or cone + + # red_metal_cone + # red_metal_teapot + # green_metal_cone + + # green_metal_cube - green or cube + out => red_metal_cone + # green_metal_teapot - green or teapot + out => red_metal_teapot + # green_metal_spot - spot + out => green_metal_cone + + edit1_name_1 = edit_name_sim[i] # both + edit1_img_1 = clip_model.encode_image(edit_img_sim[i].unsqueeze(0).to(device)) + + edit1_name_2 = edit_name_sim[i+1] # first + edit1_img_2 = clip_model.encode_image(edit_img_sim[i+1].unsqueeze(0).to(device)) + + edit1_name_3 = edit_name_sim[i+2] # second + edit1_img_3 = clip_model.encode_image(edit_img_sim[i+2].unsqueeze(0).to(device)) + + ### + + edit2_name_1 = edit_name_diff[i] + edit2_img_1 = clip_model.encode_image(edit_img_diff[i].unsqueeze(0).to(device)) + + edit2_name_2 = edit_name_diff[i+1] + edit2_img_2 = clip_model.encode_image(edit_img_diff[i+1].unsqueeze(0).to(device)) + + edit2_name_3 = edit_name_diff[i+2] + edit2_img_3 = clip_model.encode_image(edit_img_diff[i+2].unsqueeze(0).to(device)) + + names_edit_diff = [edit2_name_1, edit2_name_2, edit2_name_3] + + # get second attribute for edit: lesson_filter + + attrs = [attr.split(' ')[0], attr.split(' ')[2]] + attrs_idx = [] + for a in attrs: + attrs_idx.append(types_learning.index(a)) + + filt_names = [] + for name in names_edit_diff: + instances = name.split('_')[0:3] + filt_names.append(instances[attrs_idx[0]]+' or '+instances[attrs_idx[1]]) + + + # load centroid + centroid_i = memory[lesson]['centroid'].float().to(device) + + # load filters + filter_i_edit_1 = memory[filt_names[0]]['model']['filter'].to(device) + filter_i_edit_2 = memory[filt_names[1]]['model']['filter'].to(device) + filter_i_edit_3 = memory[filt_names[2]]['model']['filter'].to(device) + + filter_i_rec = memory[lesson]['model']['filter'].to(device) + + # forward + output = dec.forward(centroid_i) + # edit + q2p_1 = edit2_img_1 * (1 - filter_i_edit_1) + output + q2p_2 = edit2_img_2 * (1 - filter_i_edit_2) + output + q2p_3 = edit2_img_3 * (1 - filter_i_edit_3) + output + + # edit to do + + # green_metal_cube - green or cube + out => red_metal_cone + # 
green_metal_cube - green or cube + out => red_metal_cube + # green_metal_cube - green or cube + out => green_metal_cone + + # reconstruct + p2p_1 = rec_img_1 * (1 - filter_i_rec) + output + p2p_2 = rec_img_2 * (1 - filter_i_rec) + output + p2p_3 = rec_img_3 * (1 - filter_i_rec) + output + + # 111 red_metal_cone - red or cone + out => red_metal_cone + # 110 red_metal_teapot - red or cone + out => red_metal_teapot + # 011 green_metal_cone - red or cone + out => green_metal_cone + + + q2p_loss_1 = 1 - get_cos_sim(edit1_img_1.float(), q2p_1.float()) + q2p_loss_1 = q2p_loss_1.to(device) + q2p_loss_2 = 1 - get_cos_sim(edit1_img_2.float(), q2p_2.float()) + q2p_loss_2 = q2p_loss_2.to(device) + q2p_loss_3 = 1 - get_cos_sim(edit1_img_3.float(), q2p_3.float()) + q2p_loss_3 = q2p_loss_3.to(device) + + + p2p_loss_1 = 1 - get_cos_sim(rec_img_1.float(), p2p_1.float()) + p2p_loss_1 = p2p_loss_1.to(device) + p2p_loss_2 = 1 - get_cos_sim(rec_img_2.float(), p2p_2.float()) + p2p_loss_2 = p2p_loss_2.to(device) + p2p_loss_3 = 1 - get_cos_sim(rec_img_3.float(), p2p_3.float()) + p2p_loss_3 = p2p_loss_3.to(device) + + loss = q2p_loss_1 + q2p_loss_2 + q2p_loss_3 + p2p_loss_1 + p2p_loss_2 + p2p_loss_3 + loss = loss.to(device) + loss_list.append(loss) + + elif 'not' in lesson.split(): + rec_name_sim, rec_img_sim, rec_name_diff, rec_img_diff = dt.get_paired_batches(attr, lesson, bs) + edit_name_sim, edit_img_sim, edit_name_diff , edit_img_diff = dt.get_paired_batches(attr, lesson, bs) + loss_list = list() + + for i in range(0,batch_size): + + with torch.no_grad(): + # load sample + # not red + rec1_name = rec_name_sim[i] # not red + rec1_img = clip_model.encode_image(rec_img_sim[i].unsqueeze(0).to(device)) + + rec2_name = rec_name_diff[i] # red + rec2_img = clip_model.encode_image(rec_img_diff[i].unsqueeze(0).to(device)) + + edit1_name = edit_name_sim[i] # not red + edit1_img = clip_model.encode_image(edit_img_sim[i].unsqueeze(0).to(device)) + + edit2_name = edit_name_diff[i] # red + edit2_img = clip_model.encode_image(edit_img_diff[i].unsqueeze(0).to(device)) + + # load centroid + centroid_i = memory[lesson]['centroid'].float().to(device) + + # load filters + filter_i_edit = memory[lesson.split(' ')[1]]['model']['filter'].to(device) + + filter_i_rec = memory[lesson]['model']['filter'].to(device) + + # forward + output = dec.forward(centroid_i) + + # edit + q2p = edit2_img * (1 - filter_i_edit) + output # red - red + not red ---> min: sim(red and out) + + # reconstruct + p2p = rec1_img * (1 - filter_i_rec) + output # not red - not red + not red ---> min: sim(red and out)) + + q2p_loss = get_cos_sim(edit2_img.float(), q2p.float()) + q2p_loss = q2p_loss.to(device) + + p2p_loss = get_cos_sim(rec2_img.float(), p2p.float()) + p2p_loss = p2p_loss.to(device) + + loss = q2p_loss + p2p_loss + loss = loss.to(device) + loss_list.append(loss) + + stacked_loss = torch.stack(loss_list) + mean_loss = torch.mean(stacked_loss) + wandb.log({"loss": mean_loss}) + # backward + optimizer.zero_grad() + mean_loss.backward() + optimizer.step() + + # print stats + if (round+1) % 10 == 0: + print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' + .format(epoch+1, epochs, round+1, 100, loss.item())) + + memory[lesson]['decoder'] = dec.to('cpu').state_dict() + # save decoder + with open(os.path.join(out_path, model_name), 'wb') as handle: + pickle.dump(memory, handle, protocol=pickle.HIGHEST_PROTOCOL) + wandb.finish() # Finish W&B run when training is complete + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + + 
argparser.add_argument('--in_path', '-i', + help='Data input path', required=True) + argparser.add_argument('--out_path', '-o', + help='Model memory output path', required=True) + argparser.add_argument('--model_name', '-n', default='best_mem_decoder_logic_small.pickle', + help='Best model memory to be saved file name', required=False) + argparser.add_argument('--memory_path', '-m', + help='Memory input path', required=True) + + args = argparser.parse_args() + + train_decoder(args.in_path, args.out_path, args.model_name, args.memory_path, clip_model, preprocess, + 'train/', bn_train, ['rgba'], dec_dic_train_logical, all_vocabs, + epochs=5, lr=0.001, batch_size=108, latent_dim=16) diff --git a/code/main.py b/code/main.py new file mode 100644 index 0000000..ceb1665 --- /dev/null +++ b/code/main.py @@ -0,0 +1,149 @@ +import os +import torch +import clip +import time +import pickle +import random +import argparse +import torch.nn as nn +import torch.optim as optim +from PIL import Image + +from torch.utils.data import DataLoader + +import torch.multiprocessing as mp +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed import init_process_group, destroy_process_group + +from config import * +from dataset import * +from models import * + +def my_train_clip_encoder(dt, memory, attr, lesson): + # get model + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + if lesson in memory.keys(): + print("______________ loading_____________________") + model.load_state_dict(memory[lesson]['model']) + optimizer = optim.Adam(model.parameters(), lr=lr) + model.train().to(device) + + loss_sim = None + loss_dif = None + loss = 10 + ct = 0 + centroid_sim = torch.rand(1, latent_dim).to(device) + while loss > 0.008: + ct += 1 + if ct > 5: + break + for i in range(200): + # Get Inputs: sim_batch, (sim_batch, 4, 132, 132) + names_sim, images_sim, names_dif, images_dif = dt.get_paired_batches(attr,lesson) + images_sim = images_sim.to(device) + + # run similar model + z_sim = model(clip_model, images_sim) + centroid_sim = centroid_sim.detach() + centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim))) + + # Run Difference + images_dif = images_dif.to(device) + + # run difference model + z_dif = model(clip_model, images_dif) + loss_dif = get_sim_not_loss(centroid_sim, z_dif) + + # compute loss + loss = (loss_sim)**2 + (loss_dif-1)**2 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print('[', ct, ']', loss.detach().item(), loss_sim.detach().item(), + loss_dif.detach().item()) + + ############ save model ######### + with torch.no_grad(): + memory[lesson] = {'model': model.to('cpu').state_dict(), + 'arch': ['Filter', ['para_block1']], + 'centroid': centroid_sim.to('cpu') + } + return memory + +def my_clip_train(in_path, out_path, n_split, model_name, source, in_base, + types, dic, vocab, pre_trained_model=None): + # get data + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab, + clip_preprocessor=clip_preprocess) + + # load encoder models from memory + memory = {} + if pre_trained_model is not None: + print(">>>>> loading memory >>>>>") + in_memory = os.path.join(out_path, pre_trained_model) + infile = open(in_memory, 'rb') + memory = pickle.load(infile) + infile.close() + + t_tot = 0 + + if n_split == '0': + learning_list = 
types_logical_with_learning + elif n_split == '1': + learning_list = types_logical_with_learning_1 + elif n_split == '2': + learning_list = types_logical_with_learning_2 + elif n_split == '3': + learning_list = types_logical_with_learning_3 + elif n_split == '4': + learning_list = types_logical_with_learning_4 + elif n_split == '5': + learning_list = types_logical_with_learning_5 + elif n_split == '6': + learning_list = types_logical_with_learning_6 + elif n_split == '7': + learning_list = types_logical_with_learning_7 + + for i in range(epochs): + for tl in learning_list: # attr + random.shuffle(dic[tl]) + for vi in dic[tl]: # lesson + print("#################### Learning: " + str(i) + " ----- " + str(vi)) + t_start = time.time() + memory = my_train_clip_encoder(dt, memory, tl, vi) + t_end = time.time() + t_dur = t_end - t_start + t_tot += t_dur + + print("Time: ", t_dur, t_tot) + with open(os.path.join(out_path, model_name+'_'+str(n_split)+'.pickle'), 'wb') as handle: + pickle.dump(memory, handle, protocol=pickle.HIGHEST_PROTOCOL) + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument('--in_path', '-i', + help='Data input path', required=True) + argparser.add_argument('--out_path', '-o', + help='Model memory output path', required=True) + argparser.add_argument('--n_split', '-s', default=0, + help='Split number', required=None) + argparser.add_argument('--model_name', '-n', default='my_best_mem', + help='Best model memory to be saved file name', required=False) + argparser.add_argument('--pre_train', '-p', default=None, + help='Pretrained model import name (saved in outpath)', required=False) + + argparser.add_argument('--gpu_idx', '-g', default=0, + help='Select gpu index', required=False) + + args = argparser.parse_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + gpu_index = int(args.gpu_idx) + torch.cuda.set_device(gpu_index) + print('gpu:',gpu_index) + + my_clip_train(args.in_path, args.out_path, args.n_split, args.model_name, + 'train/', bn_train, ['rgba'], dic_train_logical, all_vocabs, args.pre_train) diff --git a/code/models.py b/code/models.py new file mode 100644 index 0000000..bff97c9 --- /dev/null +++ b/code/models.py @@ -0,0 +1,63 @@ +'''ResNet in PyTorch. +For Pre-activation ResNet, see 'preact_resnet.py'. +Credit: https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py +VAE Credit: https://github.com/AntixK/PyTorch-VAE/tree/a6896b944c918dd7030e7d795a8c13e5c6345ec7 +Contrastive Loss: https://lilianweng.github.io/posts/2021-05-31-contrastive/ +CLIP train: https://github.com/openai/CLIP/issues/83 + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 +''' +import clip +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +from config import * +device = "cuda" if torch.cuda.is_available() else "cpu" + + +class CLIP_AE_Encode(nn.Module): + def __init__(self, hidden_dim, latent_dim, isAE=False): + super(CLIP_AE_Encode, self).__init__() + # Build Encoder + self.fc1 = nn.Linear(512, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, latent_dim) + self.relu = nn.ReLU(inplace=True) + + if isAE: + self.filter = nn.Parameter(torch.ones((512))) + else: + self.filter = nn.Parameter(torch.rand((512))) + + def forward(self, clip_model, images): + with torch.no_grad(): + emb = clip_model.encode_image(images).float() + out = emb * self.filter + out = self.relu(self.fc1(out)) + z = self.fc2(out) + + return z + +class Decoder(nn.Module): + def __init__(self, latent_dim): + super(Decoder, self).__init__() + # Build decoder + self.fc1 = nn.Linear(latent_dim, 64) + self.dropout1 = nn.Dropout(0.2) # Dropout layer with a dropout rate of 0.2 + self.fc2 = nn.Linear(64, 64) + self.dropout2 = nn.Dropout(0.2) + self.fc3 = nn.Linear(64, 96) + self.dropout3 = nn.Dropout(0.2) + self.fc4 = nn.Linear(96, 512) + self.relu = nn.ReLU(inplace=True) + + def forward(self, z): + out = self.dropout1(self.relu(self.fc1(z))) + out = self.dropout2(self.relu(self.fc2(out))) + out = self.dropout3(self.relu(self.fc3(out))) + out = self.fc4(out) + return out diff --git a/code/multi_attr_recog_eval.py b/code/multi_attr_recog_eval.py new file mode 100644 index 0000000..4ccf947 --- /dev/null +++ b/code/multi_attr_recog_eval.py @@ -0,0 +1,222 @@ +#%% +from config import * +from dataset import * +from models import * +from tqdm import tqdm + +from torch.utils.data import DataLoader +import pickle +import clip +import argparse + +from pprint import pprint + + +def my_clip_evaluation(in_path, source, memory_path, in_base, types, dic, vocab_data, vocab_eval): + + with torch.no_grad(): + with open(memory_path, 'rb') as f: + memory = pickle.load(f) + + # get dataset + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + dt = MyDataset(in_path, source, in_base, types, dic, vocab_data, + clip_preprocessor=clip_preprocess) + data_loader = DataLoader(dt, batch_size=132, shuffle=True) + + top3 = 0 + top3_color = 0 + tot_color = 0 + top3_material = 0 + tot_material = 0 + top3_shape = 0 + tot_shape = 0 + top3_and = 0 + tot_and = 0 + top3_or = 0 + tot_or = 0 + top3_not = 0 + tot_not = 0 + + top3_logic = 0 + tot_num = 0 + tot_num_logic = 0 + + #for base_is, images in data_loader: # labels (one hot), images (clip embs) + base_is, images = next(iter(data_loader)) + + # Prepare the inputs + images = images.to(device) + ans = [] + batch_size_i = len(base_is) + + # go through memory + for label in tqdm(vocab_eval, desc="Processing labels", unit="label"): # select a label es 'red' + if label not in memory.keys(): + print('nope') + ans.append(torch.full((batch_size_i, 1), 1000.0).squeeze(1)) + continue + + # load model + model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) + model.load_state_dict(memory[label]['model']) # load weights corresponding to red + model.to(device) + model.eval() # freeze + + # load centroid + centroid_i = memory[label]['centroid'].to(device) + centroid_i = centroid_i.repeat(batch_size_i, 1) + + # compute stats + z = model(clip_model, images).squeeze(0) + disi = ((z - centroid_i)**2).mean(dim=1) # distance between z and centroid_i + ans.append(disi.detach().to('cpu')) + + ##### + + dsi = 
disi.detach().to('cpu') + values, indices = dsi.topk(1, largest=False) # get top indices + _, indices_lb = base_is.topk(3) + indices_lb, _ = torch.sort(indices_lb) + + for bi in indices: + color = indices_lb[bi][0] + material = indices_lb[bi][1] + shape = indices_lb[bi][2] + attrs = [color, material, shape] + + if len(label.split())<2: + tot_num += 1 + l1 = label + l1_idx = vocab_data.index(l1) + if color == l1_idx: + top3_color += 1 + if material == l1_idx: + top3_material += 1 + if shape == l1_idx: + top3_shape += 1 + + else: + tot_num_logic += 1 + if 'and' in label.split(): + tot_and += 1 + l1 = label.split()[0] + l2 = label.split()[2] + l1_idx = vocab_data.index(l1) + l2_idx = vocab_data.index(l2) + if (l1_idx in attrs) and (l2_idx in attrs): + top3_and += 1 + elif 'or' in label.split(): + tot_or += 1 + l1 = label.split()[0] + l2 = label.split()[2] + l1_idx = vocab_data.index(l1) + l2_idx = vocab_data.index(l2) + if (l1_idx in attrs) or (l2_idx in attrs): + top3_or += 1 + elif 'not' in label.split(): + tot_not += 1 + l1 = label.split()[1] + l1_idx = vocab_data.index(l1) + if l1_idx not in attrs: + top3_not += 1 + top3 = (top3_color + top3_material + top3_shape)/tot_num + top3_logic = (top3_and + top3_or + top3_not)/tot_num_logic + print('color',top3_color,'\nmaterial',top3_material,'\nshape',top3_shape,'\nclassic',top3) + print('and',top3_and/tot_and,'\nor',top3_or/tot_or,'\nnot',top3_not/tot_not,'\nlogic',top3_logic) + ##### +''' + # get top3 indices + ans = torch.stack(ans, dim=1) + values, indices = ans.topk(3, largest=False) + _, indices_lb = base_is.topk(3) + indices_lb, _ = torch.sort(indices_lb) + + # calculate stats + tot_num += len(indices) + + for bi in range(len(indices)): + ci = 0 + mi = 0 + si = 0 + print('***',indices[bi],'***') + if indices_lb[bi][0] in indices[bi]: + print(indices_lb[bi][0]) + ci = 1 + if indices_lb[bi][1] in indices[bi]: + print(indices_lb[bi][1]) + mi = 1 + if indices_lb[bi][2] in indices[bi]: + print(indices_lb[bi][2]) + si = 1 + + top3_color += ci + top3_material += mi + top3_shape += si + + if (ci == 1) and (mi == 1) and (si == 1): + top3 += 1 + + print(tot_num, top3_color/tot_num, top3_material/tot_num, + top3_shape/tot_num, top3/tot_num, + top3_and/tot_num_logic, top3_or/tot_num_logic, top3_not/tot_num_logic) + return top3/tot_num, top3_and/tot_num_logic, top3_or/tot_num_logic, top3_not/tot_num_logic +''' +source = 'train' +in_base = bn_train +types = ['rgba'] +dic = dic_train_logical +vocab_eval = all_vocabs +vocab_data = vocabs + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + + argparser.add_argument('--in_path', '-i', + help='Data input path', required=True) + + argparser.add_argument('--memory_path', '-m', + help='Memory input path', required=True) + + args = argparser.parse_args() + + #top3, top3_and, top3_or, top3_not = my_clip_evaluation(args.in_path, source, args.memory_path, in_base, types, dic, vocab_data, vocab_eval) + my_clip_evaluation(args.in_path, source, args.memory_path, in_base, types, dic, vocab_data, vocab_eval) + #print('top3:', top3) + #print('top3_and:', top3_and) + #print('top3_or:', top3_or) + #print('top3_not:', top3_not) + + +# load model +#clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) +#dt = MyDataset(in_path, source, in_base, types, dic, vocab_data, +# clip_preprocessor=clip_preprocess) +#data_loader = DataLoader(dt, batch_size=10, shuffle=True) +#base_is, images = next(iter(data_loader)) +# +#labels = ['red', 'green', 'blue', 'aqua'] +#batch_size_i = 10 +#ans = [] +#for 
label in labels: +# model = CLIP_AE_Encode(hidden_dim_clip, latent_dim, isAE=False) +# model.load_state_dict(memory[label]['model']) # load weights corresponding to red +# model.to(device) +# model.eval() # freeze +# +# # load centroid +# centroid_i = memory[label]['centroid'].to(device) +# centroid_i = centroid_i.repeat(batch_size_i, 1) +# +# # compute stats +# z = model(clip_model, images).squeeze(0) +# disi = ((z - centroid_i)**2).mean(dim=1) +# ans.append(disi.detach().to('cpu')) + +#ans = torch.stack(ans, dim=1) +#pprint(ans) +#values, indices = ans.topk(3, largest=False) +#pprint(indices) +#_, indices_lb = base_is.topk(3) +#indices_lb, _ = torch.sort(indices_lb) +#pprint(indices_lb) \ No newline at end of file diff --git a/code/test_decode.py b/code/test_decode.py new file mode 100644 index 0000000..785fffd --- /dev/null +++ b/code/test_decode.py @@ -0,0 +1,135 @@ +#%% +import pickle +import torch +from torch import nn +from torch.utils.data import DataLoader +import torch.nn.functional as F +import os +import clip +from models import Decoder +from config import * +from dataset import MyDataset +from util import * +from pprint import pprint +import random + +random.seed(42) + +print(torch.backends.mps.is_available()) #the MacOS is higher than 12.3+ +print(torch.backends.mps.is_built()) #MPS is activated +device = torch.device('mps') + +memory_path = '/Users/filippomerlo/Desktop/memories/best_mem_decoder_logic_small.pickle' + +with open(memory_path, 'rb') as f: + memory_base = pickle.load(f) + +def get_key_from_value(dictionary, target_value): + target = '' + for key, value in dictionary.items(): + for v in value: + if v == target_value: + target = key + return target + +in_path = '/Users/filippomerlo/Desktop/Datasets/SOLA' +source = 'train' +in_base = bn_train +types = ['rgba'] +dic = dic_train_logical +vocab = vocabs + +clip_model, clip_preprocessor = clip.load("ViT-B/32", device=device) +clip_model.eval() + +dt = MyDataset(in_path, source, in_base, types, dic, vocab, clip_preprocessor) +data_loader = DataLoader(dt, batch_size=100, shuffle=True) + +train_labels, train_features = next(iter(data_loader)) +_, idxs = train_labels.topk(3) +idxs, _ = torch.sort(idxs) + +# some operations +with torch.no_grad(): + acc = dict() + n_trials_per_attr = dict() + n_trials = 100 + for trial in range(n_trials): + # get samples for the trial + # get their one-hot encoded features + train_labels, train_features = next(iter(data_loader)) + _, idxs = train_labels.topk(3) + idxs, _ = torch.sort(idxs) + # encode the images with clip + ans = [] + for i,im in enumerate(train_features): + ans.append(clip_model.encode_image(im.unsqueeze(0).to(device)).squeeze(0)) + ans = torch.stack(ans) + + # get the answers + #for attr in types_learning: + for lesson in memory_base.keys(): + if 'decoder' in memory_base[lesson].keys(): + attr = get_key_from_value(dic, lesson) + if attr not in acc.keys(): + acc[attr] = 0 + n_trials_per_attr[attr] = 0 + n_trials_per_attr[attr] += 1 + answers = dict() + #for lesson in dic[attr]: + centroid = memory_base[lesson]['centroid'].to(device) + dec = Decoder(latent_dim).to(device) + dec.load_state_dict(memory_base[lesson]['decoder']) + decoded_rep = dec(centroid) + C = decoded_rep.repeat(ans.shape[0], 1) + disi = ((ans - C)**2).mean(dim=1).detach().to('cpu') + v, topk_idxs = disi.topk(1, largest=False) + answers[lesson] = [idxs[i] for i in topk_idxs] + + for k in answers.keys(): + for coded in answers[k]: + color = vocabs[coded[0]] + material = vocabs[coded[1]] + shape = 
vocabs[coded[2]] + if 'and' in k.split(): + l1 = k.split()[0] + l2 = k.split()[2] + if l1 in [color, material, shape] and l2 in [color, material, shape]: + acc[attr] += 1 + elif 'or' in k.split(): + l1 = k.split()[0] + l2 = k.split()[2] + if l1 in [color, material, shape] or l2 in [color, material, shape]: + acc[attr] += 1 + else: + if k in [color, material, shape]: + acc[attr] += 1 + +# print the results +for k in acc.keys(): + print(f'{k}: ',acc[k]/n_trials_per_attr[k]) +#%% +color_acc = acc['color']/(len(colors)*n_trials) +material_acc = acc['material']/(len(materials)*n_trials) +shape_acc = acc['shape']/(len(shapes)*n_trials) +print('Color accuracy: {}'.format(color_acc)) +print('Material accuracy: {}'.format(material_acc)) +print('Shape accuracy: {}'.format(shape_acc)) +tot_acc = (color_acc+material_acc+shape_acc)/3 +print('Total accuracy: {}'.format(tot_acc)) + +import matplotlib.pyplot as plt + +# Accuracy values +categories = ['Color', 'Material', 'Shape', 'Total'] +accuracies = [color_acc, material_acc, shape_acc, tot_acc] + +# Plotting +plt.figure(figsize=(8, 5)) +plt.bar(categories, accuracies, color=['blue', 'green', 'orange', 'red']) +plt.ylim(0, 1) # Setting y-axis limits to represent accuracy values between 0 and 1 +plt.title('Accuracy Metrics') +plt.xlabel('Categories') +plt.ylabel('Accuracy') +plt.show() + diff --git a/code/util.py b/code/util.py new file mode 100644 index 0000000..dc29da1 --- /dev/null +++ b/code/util.py @@ -0,0 +1,87 @@ +import torch + +from config import * +from dataset import * + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +''' Parse file names +-- Training + 'red_rubber_cylinder_shade_base_stretch_normal_scale_large_brightness_normal_view_-2_-2_2' +-- Testing + 'yellow_rubber_sponge_shade_base_stretch_normal_scale_small_brightness_dim_view_1.5_1.5_3' + + ** 0_1_2 + red_rubber_cylinder_ + ** 3_4 + shade_base_ + ** 5_6 + stretch_normal_ + ** 7_8 + scale_large_ + ** 9_10 + brightness_normal_ + ** 11_12_13_14_ + view_-2_-2_2_ + ** x.png + rgba.png + + diffn_list = [(i,j) for i, j in zip(n1, n2) if i != j] + x[3:6] = [''.join(x[3:6])] +''' + + +def pareFileNames(base_name): + # get different attr names + n1 = base_name.split('_') + # regroup shape names with '_', e.g. torus_knot + if len(n1) > 15: + n1[2:4] = ['_'.join(n1[2:4])] + # regroup view points + n1[12:15] = ['_'.join(n1[12:15])] + + nm1 = {} + nm1["color"] = n1[0] + nm1["material"] = n1[1] + nm1["shape"] = n1[2] + nm1["shade"] = n1[4] + nm1["stretch"] = n1[6] + nm1["scale"] = n1[8] + nm1["brightness"] = n1[10] + nm1["view"] = n1[12] + + return nm1 + + +def get_mse_loss(recons, images): + recons_loss = F.mse_loss(recons, images) + return recons_loss + + +def get_mse_loss_more(recons, images): + recons_loss = 0 + for i in range(images.shape[0]): + recons_loss += F.mse_loss(recons[0], images[i]) + return recons_loss/images.shape[0] + + +def get_sim_loss(z): + centroid = torch.mean(z, dim=0) + loss = 0 + for i in range(z.shape[0]): + loss += F.mse_loss(centroid, z[i]) + + return centroid, loss/z.shape[0] + + +def get_sim_not_loss(centroid, z): + loss = 0 + for i in range(z.shape[0]): + loss += F.mse_loss(centroid, z[i]) + + return loss/z.shape[0] + +def get_cos_sim(a,b): + return torch.nn.functional.cosine_similarity(a, b, dim=1) +
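
Editor's note — usage sketch (illustrative, not part of the diff above): the filename convention documented in config.py and util.py can be exercised directly with pareFileNames. The snippet below assumes it is run from inside code/ so that util.py is importable; the example name is taken from util.py's own docstring.

    # Illustrative check of the filename parser, assuming we run from inside code/.
    from util import pareFileNames

    name = "red_rubber_cylinder_shade_base_stretch_normal_scale_large_brightness_normal_view_-2_-2_2"
    attrs = pareFileNames(name)
    print(attrs)
    # Expected parse (multi-token shapes such as "torus_knot" are regrouped, and the
    # three view coordinates are rejoined into a single field):
    # {'color': 'red', 'material': 'rubber', 'shape': 'cylinder', 'shade': 'base',
    #  'stretch': 'normal', 'scale': 'large', 'brightness': 'normal', 'view': '-2_-2_2'}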
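
Editor's note — a second sketch (illustrative, not the project's actual code path): my_train_clip_encoder in main.py trains one small CLIP_AE_Encode per lesson by pulling the latents of matching images toward a running centroid (get_sim_loss) while pushing the latents of non-matching images to a mean squared distance of about 1 (get_sim_not_loss), combining the two as loss = loss_sim**2 + (loss_dif - 1)**2. The snippet below condenses those two helpers from util.py and substitutes random tensors for the CLIP-derived latents, just to make the shape of the objective concrete.

    import torch
    import torch.nn.functional as F

    def get_sim_loss(z):
        # mean of the batch acts as the lesson centroid; pull every z[i] toward it
        centroid = torch.mean(z, dim=0)
        loss = sum(F.mse_loss(centroid, z[i]) for i in range(z.shape[0])) / z.shape[0]
        return centroid, loss

    def get_sim_not_loss(centroid, z):
        # average squared distance of the "different" batch from the centroid
        return sum(F.mse_loss(centroid, z[i]) for i in range(z.shape[0])) / z.shape[0]

    latent_dim = 16                            # matches config.py
    z_sim = torch.randn(32, latent_dim)        # stand-in latents for images that share the lesson
    z_dif = torch.randn(32, latent_dim)        # stand-in latents for images that differ on the lesson
    prev_centroid = torch.rand(1, latent_dim)  # running centroid carried across steps, as in main.py

    centroid, loss_sim = get_sim_loss(torch.vstack((z_sim, prev_centroid)))
    loss_dif = get_sim_not_loss(centroid, z_dif)

    # similar items should collapse onto the centroid (loss_sim -> 0);
    # different items should settle at squared distance ~1 (loss_dif -> 1)
    loss = loss_sim ** 2 + (loss_dif - 1) ** 2
    print(loss.item())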