Skip to content

Commit

Permalink
resize the image
Browse files Browse the repository at this point in the history
  • Loading branch information
jeasinema committed Apr 20, 2018
1 parent 83f4d9c commit 2b8411c
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 81 deletions.
4 changes: 2 additions & 2 deletions cdmp_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# File Name : cdmp_image.py
# Purpose :
# Creation Date : 19-04-2018
# Last Modified : 2018年04月19日 星期四 21时43分20秒
# Last Modified : 2018年04月20日 星期五 16时57分15秒
# Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]

import torch
Expand Down Expand Up @@ -48,7 +48,7 @@ def __len__(self):


class CDMP_Image_Localization(data.Dataset):
def __init__(self, data_path, dataset_size, image_size=224, obj_size=224, obj_sum=10,
def __init__(self, data_path, dataset_size, image_size=224, obj_size=48, obj_sum=10,
collision_radius=50, border=25):
self.data = CDMP_Synthesis(data_path, obj_sum, collision_radius, border, image_size, obj_size)
self.dataset_size = dataset_size
Expand Down
9 changes: 4 additions & 5 deletions data/scripts/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@


class CDMP_Synthesis(object):
def __init__(self, data_path, obj_sum=10, collision_radius=50, border=25, image_size=224, object_size=224):
def __init__(self, data_path, obj_sum=10, collision_radius=50, border=25, image_size=224, object_size=48):
self.data_path = data_path
self.obj_sum = obj_sum
self.collision_radius = collision_radius
self.border = border
self.cls_dict, self.label = self.get_all_labels()
self.bg_h = 2400
self.bg_w = 2700
self.bg_h = 240
self.bg_w = 270
self.image_size = image_size
self.object_size = object_size
self.factor = (self.bg_w+self.bg_h)/2/image_size

def get_all_labels(self):
obj_cls = [i.split(
Expand Down Expand Up @@ -59,7 +58,7 @@ def random_place(self, obj=False, seed=0):
obj_img = []
for name in f:
im = cv2.imread(name, cv2.IMREAD_UNCHANGED)
im = cv2.resize(im, (int(im.shape[1]/self.factor), int(im.shape[0]/self.factor)))
im = cv2.resize(im, (int(im.shape[1]), int(im.shape[0])))
obj_img.append(im)

redo = True
Expand Down
Binary file modified data/scripts/output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
117 changes: 59 additions & 58 deletions main_cdmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

# dataset hyper params
min_dim = 224
object_size = 224
object_size = 48
means = (104, 117, 123)

train_loader = torch.utils.data.DataLoader(
Expand All @@ -51,25 +51,25 @@
obj_size=object_size,
),
batch_size=args.batch_size,
num_workers=32,
num_workers=1,
pin_memory=True,
shuffle=True,
collate_fn=collect_fn_image_localization
)


test_loader = torch.utils.data.DataLoader(
CDMP_Image_Localization(data_path='/home1/mxj/workspace/CDMP-localization/data',
dataset_size=1000,
image_size=min_dim,
obj_size=object_size,
),
batch_size=args.batch_size,
num_workers=32,
pin_memory=True,
shuffle=True,
collate_fn=collect_fn_image_localization
)
# test_loader = torch.utils.data.DataLoader(
# CDMP_Image_Localization(data_path='/home1/mxj/workspace/CDMP-localization/data',
# dataset_size=1000,
# image_size=min_dim,
# obj_size=object_size,
# ),
# batch_size=args.batch_size,
# num_workers=32,
# pin_memory=True,
# shuffle=True,
# collate_fn=collect_fn_image_localization
# )


def train(model, loader, epoch, optimizer):
Expand Down Expand Up @@ -110,42 +110,42 @@ def train(model, loader, epoch, optimizer):
logger.add_image('train_obj_img', obj_img, batch_idx + epoch * len(loader))


def test(model, loader, epoch):
model.eval()
test_loss = 0
for batch_idx, (img, object_img, target) in enumerate(loader):
# img: (N, C, H, W)
# object_img: (N, C, H, W)
# target: (N, 3) [x, y, id]
if args.cuda:
img, object_img, target = img.cuda(), object_img.cuda(), target.cuda()
img, object_img, target = Variable(img), Variable(object_img), Variable(target)
output = model(img, object_img)
loss = F.mse_loss(output, target[:, :2]).mean() # because you've already using log_softmax as output
test_loss += loss.data.cpu()
if batch_idx % args.log_interval == 0:
# visualize gt
y, x = np.clip((output[0]*min_dim).data.cpu().numpy().astype(np.int32), a_min=0, a_max=min_dim-1)
gt_y, gt_x, label_id = (target[0]*min_dim).data.cpu().numpy().astype(np.int32)
label_id = int(target[0, -1])
log_img = (img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
log_img = cv2.circle(log_img, (x, y), 10, (255,0,0), 5)
cv2.putText(log_img, loader.dataset.label[label_id], (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
(255,0,0))
log_img = cv2.circle(log_img, (gt_x, gt_y), 10, (0,255,0), 5)
cv2.putText(log_img, loader.dataset.label[label_id], (gt_x, gt_y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
(0,255,0))
obj_img = (object_img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
log_img[..., :] = log_img[..., [2,1,0]]
obj_img[..., :] = obj_img[..., [2,1,0]]
logger.add_image('test_log_img', log_img, batch_idx + epoch * len(loader))
logger.add_image('test_obj_img', obj_img, batch_idx + epoch * len(loader))


test_loss /= len(loader.dataset)
# visualize gt
# TBD
return test_loss
# def test(model, loader, epoch):
# model.eval()
# test_loss = 0
# for batch_idx, (img, object_img, target) in enumerate(loader):
# # img: (N, C, H, W)
# # object_img: (N, C, H, W)
# # target: (N, 3) [x, y, id]
# if args.cuda:
# img, object_img, target = img.cuda(), object_img.cuda(), target.cuda()
# img, object_img, target = Variable(img), Variable(object_img), Variable(target)
# output = model(img, object_img)
# loss = F.mse_loss(output, target[:, :2]).mean() # because you've already using log_softmax as output
# test_loss += loss.data.cpu()
# if batch_idx % args.log_interval == 0:
# # visualize gt
# y, x = np.clip((output[0]*min_dim).data.cpu().numpy().astype(np.int32), a_min=0, a_max=min_dim-1)
# gt_y, gt_x, label_id = (target[0]*min_dim).data.cpu().numpy().astype(np.int32)
# label_id = int(target[0, -1])
# log_img = (img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
# log_img = cv2.circle(log_img, (x, y), 10, (255,0,0), 5)
# cv2.putText(log_img, loader.dataset.label[label_id], (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
# (255,0,0))
# log_img = cv2.circle(log_img, (gt_x, gt_y), 10, (0,255,0), 5)
# cv2.putText(log_img, loader.dataset.label[label_id], (gt_x, gt_y), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
# (0,255,0))
# obj_img = (object_img[0].permute(1,2,0).data.cpu().numpy()*255).astype(np.uint8)
# log_img[..., :] = log_img[..., [2,1,0]]
# obj_img[..., :] = obj_img[..., [2,1,0]]
# logger.add_image('test_log_img', log_img, batch_idx + epoch * len(loader))
# logger.add_image('test_obj_img', obj_img, batch_idx + epoch * len(loader))
#
#
# test_loss /= len(loader.dataset)
# # visualize gt
# # TBD
# return test_loss


def main():
Expand All @@ -157,18 +157,19 @@ def main():
optimizer = optim.Adam(model.parameters(), lr=args.lr)
for epoch in range(args.epoch):
train(model, train_loader, epoch, optimizer)
loss = test(model, test_loader, epoch)
logger.add_scalar('test_loss', loss, epoch)
# loss = test(model, test_loader, epoch)
# logger.add_scalar('test_loss', loss, epoch)
if epoch % args.save_interval == 0:
torch.save(model, os.path.join(args.model_path,
args.tag + '_{}.model'.format(epoch)))
else:
model = torch.load(os.path.join(args.model_path, args.tag + '.model'))
if args.cuda:
model = nn.DataParallel(model, device_ids=device_id).cuda()
# model = model.cuda()
loss = test(model, test_loader, 0)
print('Test done, loss={}'.format(loss))
pass
# model = torch.load(os.path.join(args.model_path, args.tag + '.model'))
# if args.cuda:
# model = nn.DataParallel(model, device_ids=device_id).cuda()
# # model = model.cuda()
# loss = test(model, test_loader, 0)
# print('Test done, loss={}'.format(loss))

if __name__ == "__main__":
main()
91 changes: 75 additions & 16 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# File Name : model.py
# Purpose :
# Creation Date : 19-04-2018
# Last Modified : 2018年04月19日 星期四 14时22分51秒
# Last Modified : 2018年04月20日 星期五 17时07分38秒
# Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]


Expand All @@ -13,31 +13,90 @@
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
from torch.nn.parameter import Parameter
import numpy as np


class SpatialSoftmax(torch.nn.Module):
    """Spatial soft-argmax: reduce each feature map to an expected (x, y) point.

    For every channel, a softmax is taken over all ``height * width``
    locations and used as a weighting over a fixed coordinate grid spanning
    ``[-1, 1]`` on both axes.  The result is one ``(x, y)`` pair per channel.

    Args:
        height: Height (H) of the incoming feature maps.
        width: Width (W) of the incoming feature maps.
        channel: Number of channels (C) of the incoming feature maps.
        temperature: If truthy, a learnable softmax temperature initialised
            to this value; otherwise a constant temperature of ``1.`` is used.
        data_format: ``'NHWC'`` if the input is laid out as (N, H, W, C);
            any other value is treated as (N, C, H, W).

    Shape:
        forward input:  (N, C, H, W) or (N, H, W, C) depending on
        ``data_format``.
        forward output: (N, C * 2), ordered ``[x_0, y_0, x_1, y_1, ...]``.
    """

    def __init__(self, height, width, channel, temperature=None, data_format='NCHW'):
        super(SpatialSoftmax, self).__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            # Registered as a Parameter so the temperature is trained.
            self.temperature = Parameter(torch.ones(1)*temperature)
        else:
            self.temperature = 1.

        # np.meshgrid(a, b) returns arrays of shape (len(b), len(a)).  The
        # feature maps are flattened row-major as H*W in forward(), so the
        # grid must have shape (H, W): x varies along the width axis and
        # y along the height axis.  Passing (height, width) here instead
        # only works when H == W.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.width),
            np.linspace(-1., 1., self.height)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height*self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height*self.width)).float()
        # Buffers follow the module across .cuda()/.cpu() but are not trained.
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        """Return the expected (x, y) location of every channel.

        Args:
            feature: Feature maps laid out according to ``self.data_format``.

        Returns:
            Tensor of shape (N, channel * 2) with coordinates in ``[-1, 1]``.
        """
        if self.data_format == 'NHWC':
            # (N, H, W, C) -> (N, C, H, W) so that each row of the flattened
            # tensor below is one complete H*W map of a single channel.
            feature = feature.permute(0, 3, 1, 2)
        # .contiguous() is required because view() rejects the permuted
        # (or otherwise non-contiguous) memory layout.
        feature = feature.contiguous().view(-1, self.height*self.width)

        softmax_attention = F.softmax(feature/self.temperature, dim=-1)
        expected_x = torch.sum(Variable(self.pos_x)*softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(Variable(self.pos_y)*softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # (N*C, 2) -> (N, C*2)
        feature_keypoints = expected_xy.view(-1, self.channel*2)

        return feature_keypoints


class CDMP_Localization(nn.Module):
def __init__(self, input_size, object_size):
def __init__(self, input_size, object_size, channel=3):
super(CDMP_Localization, self).__init__()
self.input_size = input_size
self.object_size = object_size
self.ch = channel

self.pool = torch.nn.MaxPool2d(2, padding=1, stride=2)
# for image input
self.conv1_img = torch.nn.Conv2d(self.ch, 64, kernel_size=4, padding=1, stride=2)
self.conv2_img = torch.nn.Conv2d(64, 64, kernel_size=4, padding=1, stride=2)
self.conv3_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv4_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv5_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv6_img = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.spatial_softmax = SpatialSoftmax(self.input_size // 2 // 2, self.input_size // 2 // 2, 64) # (N, 64*2)

# for object input
self.conv1_obj = torch.nn.Conv2d(self.ch, 64, kernel_size=3, padding=1)
self.conv2_obj = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv3_obj = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)

self.resnet_img = models.resnet18(pretrained=False, num_classes=256)
self.resnet_obj = models.resnet18(pretrained=False, num_classes=64)
param = torch.load('./assets/resnet18.pth')
del param['fc.weight']
del param['fc.bias']
self.resnet_img.load_state_dict(param, strict=False)
self.resnet_obj.load_state_dict(param, strict=False)
self.fc1 = nn.Linear(256+64, 256)
self.center = nn.Linear(256, 2)
# self.center = torch.nn.Linear(64*2 + 64 * (self.object_size // 2 // 2 // 2 // 2 // 2)**2 +
# 64 * (self.input_size // 2 // 2 // 2)**2, 2)
self.center = torch.nn.Linear(128+53824+3136, 2)

def forward(self, img, obj_img):
img_x = self.resnet_img(img)
obj_img_x = self.resnet_obj(obj_img)
batch_size = img.shape[0]
img_x = F.relu(self.conv1_img(img))
img_x = F.relu(self.conv2_img(img_x))
img_x = F.relu(self.conv3_img(img_x))
img_x = F.relu(self.conv4_img(img_x))
img_x = F.relu(self.conv5_img(img_x))
img_x = F.relu(self.conv6_img(img_x))
points = self.spatial_softmax(img_x)
feature = self.pool(img_x).view(batch_size, -1)

x = F.relu(self.fc1(torch.cat([img_x, obj_img_x], -1)))

return self.center(x)
obj_x = self.pool(F.relu(self.conv1_obj(obj_img)))
obj_x = self.pool(F.relu(self.conv2_obj(obj_x)))
obj_x = self.pool(F.relu(self.conv3_obj(obj_x))).view(batch_size, -1)
# print(points.shape, feature.shape, obj_x.shape)

return self.center(torch.cat([feature, obj_x, points], -1))


if __name__ == '__main__':
Expand Down

0 comments on commit 2b8411c

Please sign in to comment.