diff --git a/cdmp_image.py b/cdmp_image.py
index 445756d..35283cf 100644
--- a/cdmp_image.py
+++ b/cdmp_image.py
@@ -4,7 +4,7 @@
 # File Name : cdmp_image.py
 # Purpose :
 # Creation Date : 19-04-2018
-# Last Modified : 2018年04月19日 星期四 21时43分20秒
+# Last Modified : 2018年04月20日 星期五 16时57分15秒
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 import torch
@@ -48,7 +48,7 @@ def __len__(self):
 
 
 class CDMP_Image_Localization(data.Dataset):
-    def __init__(self, data_path, dataset_size, image_size=224, obj_size=224, obj_sum=10,
+    def __init__(self, data_path, dataset_size, image_size=224, obj_size=48, obj_sum=10,
                  collision_radius=50, border=25):
         self.data = CDMP_Synthesis(data_path, obj_sum, collision_radius, border, image_size, obj_size)
         self.dataset_size = dataset_size
diff --git a/data/scripts/data.py b/data/scripts/data.py
index 4d09e39..94fa609 100644
--- a/data/scripts/data.py
+++ b/data/scripts/data.py
@@ -14,17 +14,16 @@
 
 
 class CDMP_Synthesis(object):
-    def __init__(self, data_path, obj_sum=10, collision_radius=50, border=25, image_size=224, object_size=224):
+    def __init__(self, data_path, obj_sum=10, collision_radius=50, border=25, image_size=224, object_size=48):
         self.data_path = data_path
         self.obj_sum = obj_sum
         self.collision_radius = collision_radius
         self.border = border
         self.cls_dict, self.label = self.get_all_labels()
-        self.bg_h = 2400
-        self.bg_w = 2700
+        self.bg_h = 240
+        self.bg_w = 270
         self.image_size = image_size
         self.object_size = object_size
-        self.factor = (self.bg_w+self.bg_h)/2/image_size
 
     def get_all_labels(self):
         obj_cls = [i.split(
@@ -59,7 +58,7 @@ def random_place(self, obj=False, seed=0):
         obj_img = []
         for name in f:
             im = cv2.imread(name, cv2.IMREAD_UNCHANGED)
-            im = cv2.resize(im, (int(im.shape[1]/self.factor), int(im.shape[0]/self.factor)))
+            im = cv2.resize(im, (int(im.shape[1]), int(im.shape[0])))
             obj_img.append(im)
 
         redo = True
diff --git a/data/scripts/output.png b/data/scripts/output.png
index aa7243c..58a8bc3 100644
Binary files a/data/scripts/output.png and b/data/scripts/output.png differ
diff --git a/main_cdmp.py b/main_cdmp.py
index 125a6c4..7b8640f 100644
--- a/main_cdmp.py
+++ b/main_cdmp.py
@@ -41,7 +41,7 @@
 
 # dataset hyper params
 min_dim = 224
-object_size = 224
+object_size = 48
 means = (104, 117, 123)
 
 train_loader = torch.utils.data.DataLoader(
@@ -51,25 +51,25 @@
                             obj_size=object_size,
                             ),
     batch_size=args.batch_size,
-    num_workers=32,
+    num_workers=1,
     pin_memory=True,
     shuffle=True,
     collate_fn=collect_fn_image_localization
 )
 
 
-test_loader = torch.utils.data.DataLoader(
-    CDMP_Image_Localization(data_path='/home1/mxj/workspace/CDMP-localization/data',
-                            dataset_size=1000,
-                            image_size=min_dim,
-                            obj_size=object_size,
-                            ),
-    batch_size=args.batch_size,
-    num_workers=32,
-    pin_memory=True,
-    shuffle=True,
-    collate_fn=collect_fn_image_localization
-)
+# test_loader = torch.utils.data.DataLoader(
+#     CDMP_Image_Localization(data_path='/home1/mxj/workspace/CDMP-localization/data',
+#                             dataset_size=1000,
+#                             image_size=min_dim,
+#                             obj_size=object_size,
+#                             ),
+#     batch_size=args.batch_size,
+#     num_workers=32,
+#     pin_memory=True,
+#     shuffle=True,
+#     collate_fn=collect_fn_image_localization
+# )
 
 
 def train(model, loader, epoch, optimizer):
@@ -110,42 +110,42 @@ def train(model, loader, epoch, optimizer):
             logger.add_image('train_obj_img', obj_img, batch_idx + epoch * len(loader))
 
 
-def test(model, loader, epoch):
-    model.eval()
-    test_loss = 0
-    for batch_idx, (img, object_img, target) in enumerate(loader):
-        # img: (N, C, H, W)
-        # object_img: (N, C, H, W)
-        # target: (N, 3) [x, y, id]
-        if args.cuda:
-            img, object_img, target = img.cuda(), object_img.cuda(), target.cuda()
-        img, object_img, target = Variable(img), Variable(object_img), Variable(target)
-        output = model(img, object_img)
-        loss = F.mse_loss(output, target[:, :2]).mean() # because you've already using log_softmax as output
-        test_loss += loss.data.cpu()
-        if batch_idx % args.log_interval == 0:
-            # visualize gt
-            y, x = np.clip((output[0]*min_dim).data.cpu().numpy().astype(np.int32), a_min=0, a_max=min_dim-1)
-            gt_y, gt_x, label_id = (target[0]*min_dim).data.cpu().numpy().astype(np.int32)
-            label_id = int(target[0, -1])
-            log_img = (img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
-            log_img = cv2.circle(log_img, (x, y), 10, (255,0,0), 5)
-            cv2.putText(log_img, loader.dataset.label[label_id], (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
-                        (255,0,0))
-            log_img = cv2.circle(log_img, (gt_x, gt_y), 10, (0,255,0), 5)
-            cv2.putText(log_img, loader.dataset.label[label_id], (gt_x, gt_y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
-                        (0,255,0))
-            obj_img = (object_img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
-            log_img[..., :] = log_img[..., [2,1,0]]
-            obj_img[..., :] = obj_img[..., [2,1,0]]
-            logger.add_image('test_log_img', log_img, batch_idx + epoch * len(loader))
-            logger.add_image('test_obj_img', obj_img, batch_idx + epoch * len(loader))
-
-
-    test_loss /= len(loader.dataset)
-    # visualize gt
-    # TBD
-    return test_loss
+# def test(model, loader, epoch):
+#     model.eval()
+#     test_loss = 0
+#     for batch_idx, (img, object_img, target) in enumerate(loader):
+#         # img: (N, C, H, W)
+#         # object_img: (N, C, H, W)
+#         # target: (N, 3) [x, y, id]
+#         if args.cuda:
+#             img, object_img, target = img.cuda(), object_img.cuda(), target.cuda()
+#         img, object_img, target = Variable(img), Variable(object_img), Variable(target)
+#         output = model(img, object_img)
+#         loss = F.mse_loss(output, target[:, :2]).mean()  # because you're already using log_softmax as output
+#         test_loss += loss.data.cpu()
+#         if batch_idx % args.log_interval == 0:
+#             # visualize gt
+#             y, x = np.clip((output[0]*min_dim).data.cpu().numpy().astype(np.int32), a_min=0, a_max=min_dim-1)
+#             gt_y, gt_x, label_id = (target[0]*min_dim).data.cpu().numpy().astype(np.int32)
+#             label_id = int(target[0, -1])
+#             log_img = (img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
+#             log_img = cv2.circle(log_img, (x, y), 10, (255,0,0), 5)
+#             cv2.putText(log_img, loader.dataset.label[label_id], (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
+#                         (255,0,0))
+#             log_img = cv2.circle(log_img, (gt_x, gt_y), 10, (0,255,0), 5)
+#             cv2.putText(log_img, loader.dataset.label[label_id], (gt_x, gt_y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
+#                         (0,255,0))
+#             obj_img = (object_img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
+#             log_img[..., :] = log_img[..., [2,1,0]]
+#             obj_img[..., :] = obj_img[..., [2,1,0]]
+#             logger.add_image('test_log_img', log_img, batch_idx + epoch * len(loader))
+#             logger.add_image('test_obj_img', obj_img, batch_idx + epoch * len(loader))
+#
+#
+#     test_loss /= len(loader.dataset)
+#     # visualize gt
+#     # TBD
+#     return test_loss
 
 
 def main():
@@ -157,18 +157,19 @@ def main():
         optimizer = optim.Adam(model.parameters(), lr=args.lr)
         for epoch in range(args.epoch):
             train(model, train_loader, epoch, optimizer)
-            loss = test(model, test_loader, epoch)
-            logger.add_scalar('test_loss', loss, epoch)
+            # loss = test(model, test_loader, epoch)
+            # logger.add_scalar('test_loss', loss, epoch)
             if epoch % args.save_interval == 0:
                 torch.save(model,
                            os.path.join(args.model_path, args.tag + '_{}.model'.format(epoch)))
     else:
-        model = torch.load(os.path.join(args.model_path, args.tag + '.model'))
-        if args.cuda:
-            model = nn.DataParallel(model, device_ids=device_id).cuda()
-            # model = model.cuda()
-        loss = test(model, test_loader, 0)
-        print('Test done, loss={}'.format(loss))
+        pass
+        # model = torch.load(os.path.join(args.model_path, args.tag + '.model'))
+        # if args.cuda:
+        #     model = nn.DataParallel(model, device_ids=device_id).cuda()
+        #     # model = model.cuda()
+        # loss = test(model, test_loader, 0)
+        # print('Test done, loss={}'.format(loss))
 
 if __name__ == "__main__":
     main()
diff --git a/model.py b/model.py
index 2d5cbf3..01b089f 100644
--- a/model.py
+++ b/model.py
@@ -4,7 +4,7 @@
 # File Name : model.py
 # Purpose :
 # Creation Date : 19-04-2018
-# Last Modified : 2018年04月19日 星期四 14时22分51秒
+# Last Modified : 2018年04月20日 星期五 17时07分38秒
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 
@@ -13,31 +13,90 @@
 import torch.nn.functional as F
 from torch.autograd import Variable
 import torchvision.models as models
+from torch.nn.parameter import Parameter
+import numpy as np
+
+
+class SpatialSoftmax(torch.nn.Module):
+    def __init__(self, height, width, channel, temperature=None, data_format='NCHW'):
+        super(SpatialSoftmax, self).__init__()
+        self.data_format = data_format
+        self.height = height
+        self.width = width
+        self.channel = channel
+
+        if temperature:
+            self.temperature = Parameter(torch.ones(1)*temperature)
+        else:
+            self.temperature = 1.
+
+        pos_x, pos_y = np.meshgrid(
+            np.linspace(-1., 1., self.height),
+            np.linspace(-1., 1., self.width)
+        )
+        pos_x = torch.from_numpy(pos_x.reshape(self.height*self.width)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self.height*self.width)).float()
+        self.register_buffer('pos_x', pos_x)
+        self.register_buffer('pos_y', pos_y)
+
+    def forward(self, feature):
+        if self.data_format == 'NHWC':
+            feature = feature.permute(0,2,3,1).view(-1, self.height*self.width)
+        else:
+            feature = feature.view(-1, self.height*self.width)
+
+        softmax_attention = F.softmax(feature/self.temperature, dim=-1)
+        expected_x = torch.sum(Variable(self.pos_x)*softmax_attention, dim=1, keepdim=True)
+        expected_y = torch.sum(Variable(self.pos_y)*softmax_attention, dim=1, keepdim=True)
+        expected_xy = torch.cat([expected_x, expected_y], 1)
+        feature_keypoints = expected_xy.view(-1, self.channel*2)
+
+        return feature_keypoints
 
 
 class CDMP_Localization(nn.Module):
-    def __init__(self, input_size, object_size):
+    def __init__(self, input_size, object_size, channel=3):
         super(CDMP_Localization, self).__init__()
         self.input_size = input_size
         self.object_size = object_size
+        self.ch = channel
+
+        self.pool = torch.nn.MaxPool2d(2, padding=1, stride=2)
+        # for image input
+        self.conv1_img = torch.nn.Conv2d(self.ch, 64, kernel_size=4, padding=1, stride=2)
+        self.conv2_img = torch.nn.Conv2d(64, 64, kernel_size=4, padding=1, stride=2)
+        self.conv3_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.conv4_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.conv5_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.conv6_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.spatial_softmax = SpatialSoftmax(self.input_size // 2 // 2, self.input_size // 2 // 2, 64)  # (N, 64*2)
+
+        # for object input
+        self.conv1_obj = torch.nn.Conv2d(self.ch, 64, kernel_size=3, padding=1)
+        self.conv2_obj = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.conv3_obj = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
 
-        self.resnet_img = models.resnet18(pretrained=False, num_classes=256)
-        self.resnet_obj = models.resnet18(pretrained=False, num_classes=64)
-        param = torch.load('./assets/resnet18.pth')
-        del param['fc.weight']
-        del param['fc.bias']
-        self.resnet_img.load_state_dict(param, strict=False)
-        self.resnet_obj.load_state_dict(param, strict=False)
-        self.fc1 = nn.Linear(256+64, 256)
-        self.center = nn.Linear(256, 2)
+        # self.center = torch.nn.Linear(64*2 + 64 * (self.object_size // 2 // 2 // 2 // 2 // 2)**2 +
+        #                               64 * (self.input_size // 2 // 2 // 2)**2, 2)
+        self.center = torch.nn.Linear(128+53824+3136, 2)
 
     def forward(self, img, obj_img):
-        img_x = self.resnet_img(img)
-        obj_img_x = self.resnet_obj(obj_img)
+        batch_size = img.shape[0]
+        img_x = F.relu(self.conv1_img(img))
+        img_x = F.relu(self.conv2_img(img_x))
+        img_x = F.relu(self.conv3_img(img_x))
+        img_x = F.relu(self.conv4_img(img_x))
+        img_x = F.relu(self.conv5_img(img_x))
+        img_x = F.relu(self.conv6_img(img_x))
+        points = self.spatial_softmax(img_x)
+        feature = self.pool(img_x).view(batch_size, -1)
 
-        x = F.relu(self.fc1(torch.cat([img_x, obj_img_x], -1)))
-
-        return self.center(x)
+        obj_x = self.pool(F.relu(self.conv1_obj(obj_img)))
+        obj_x = self.pool(F.relu(self.conv2_obj(obj_x)))
+        obj_x = self.pool(F.relu(self.conv3_obj(obj_x))).view(batch_size, -1)
+        # print(points.shape, feature.shape, obj_x.shape)
+
+        return self.center(torch.cat([feature, obj_x, points], -1))
 
 
 if __name__ == '__main__':
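
Reviewer note: a quick sanity check of the new SpatialSoftmax layer (a sketch; assumes the model.py above is importable from the repo root and a PyTorch version where Variable is a thin no-op wrapper, i.e. >= 0.4). A sharply peaked activation map should produce normalized coordinates at the peak:

    import torch
    from model import SpatialSoftmax

    # One 8x8 map with a single strong peak in the top-left corner.
    ss = SpatialSoftmax(height=8, width=8, channel=1)
    feat = torch.full((1, 1, 8, 8), -10.0)
    feat[0, 0, 0, 0] = 10.0
    print(ss(feat))  # shape (1, 2); both coords near -1.0, the corner of the [-1, 1] grid

One caveat: the np.meshgrid call's x/y ordering appears to line up with the flattened feature map only because height == width everywhere in this diff; non-square inputs would need the grid construction revisited.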
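Reviewer note: the hard-coded torch.nn.Linear(128+53824+3136, 2) width follows from the defaults above: SpatialSoftmax emits 64 channels * 2 coords = 128; the image branch halves 224 twice to 56, and MaxPool2d(2, padding=1, stride=2) maps 56 -> (56 + 2 - 2)//2 + 1 = 29, giving 64*29*29 = 53824; the object branch pools 48 -> 25 -> 13 -> 7, giving 64*7*7 = 3136. The commented-out formula floor-halves without the pool's padding=1 (e.g. 224//2//2//2 = 28, not 29), which is presumably why it was replaced by the literal. A minimal shape check, under the same import assumptions as above:

    import torch
    from model import CDMP_Localization

    model = CDMP_Localization(input_size=224, object_size=48)
    img = torch.randn(2, 3, 224, 224)  # scene at min_dim=224
    obj = torch.randn(2, 3, 48, 48)    # object crop at object_size=48
    out = model(img, obj)
    assert out.shape == (2, 2)         # one normalized (x, y) center per sample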