Commit
more updates on densenet
ajing committed Jul 6, 2017
1 parent 8136100 commit 3fb8552
Showing 3 changed files with 90 additions and 19 deletions.
26 changes: 15 additions & 11 deletions datasets.py
@@ -130,9 +130,11 @@ def load_img(filepath):
np.seterr(all='warn')

if is_image_file(filepath):
- image = cv2.imread(filepath)# image = io.imread(filepath) # image = Image.open(filepath)
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- # image = image.convert('RGB')
+ #image = cv2.imread(filepath)# image = io.imread(filepath) #
+ image = Image.open(filepath)
+
+ #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = image.convert('RGB')
elif '.tif' in filepath:
tif_image = io.imread(filepath)
image = np.empty_like(tif_image).astype(np.int32)
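
Why this hunk matters: cv2.imread returns a numpy ndarray in BGR channel order, and an ndarray's .size attribute is a single integer (the total element count), so torchvision 0.1.8 transforms such as Scale, whose first statement is w, h = img.size, fail on it; this is exactly the TypeError captured in the notebook below. PIL's Image.size is a (width, height) tuple, which those transforms expect. A minimal sketch of the two loading paths, where 'example.jpg' is a hypothetical placeholder rather than a file from this repository:

# Sketch, not repository code: the two loading paths side by side.
# 'example.jpg' is a hypothetical placeholder path.
import cv2
from PIL import Image

arr = cv2.imread('example.jpg')                 # ndarray in BGR channel order
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)      # still an ndarray; arr.size is one int
img = Image.open('example.jpg').convert('RGB')  # PIL Image; img.size == (width, height)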
@@ -168,19 +170,21 @@ def __init__(self, image_dir, label_dir=None, num_labels=17, mode='Train', input
self.images = []
suffix = '.jpg' if tif is False else '.tif'
print('[*]Loading Dataset {}'.format(image_dir))
+ print('[*]The current mode is {}'.format(mode))
t = time.time()
if mode == 'Train' or mode == 'Validation':
self.targets = []
self.labels = pd.read_csv(label_dir)
if read_all:
image_names = pd.read_csv('../dataset/train_all.csv')
else:
- image_names = pd.read_csv(TRAIN_SPLIT if mode == 'Train' else VAL_SPLIT)
+ image_names = pd.read_csv(TRAIN_SPLIT if mode == 'Train' else VAL_SPLIT, header = None)
image_names = image_names.as_matrix().flatten()
print("image_names", image_names)
+ print("image_names length", len(image_names))
self.image_filenames = image_names
for im_name in image_names:
- print("Current image:", im_name)
+ #print("Current image:", im_name)
str_target = self.labels.loc[self.labels['image_name'] == im_name]
image_file = os.path.join(image_dir, '{}{}'.format(im_name, suffix))
target = np.zeros(num_labels, dtype=np.float32)
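
The header = None addition above is easy to miss but substantive. Assuming the split files behind TRAIN_SPLIT and VAL_SPLIT are plain lists of image names with no header row, the old call let pandas promote the first image name to a column header, silently dropping one sample. A small sketch of the difference, using an in-memory stand-in for the split file with hypothetical names:

# Sketch with an assumed split-file layout: one image name per line, no header row.
import io
import pandas as pd

split = "train_0\ntrain_1\ntrain_2\n"
bad = pd.read_csv(io.StringIO(split))                # 'train_0' is consumed as the header
good = pd.read_csv(io.StringIO(split), header=None)  # all three names survive
print(len(bad), len(good))                           # 2 3
print(good.as_matrix().flatten())                    # ['train_0' 'train_1' 'train_2']
# as_matrix() matches the 2017-era pandas used in this diff; newer pandas spells it .values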
Expand All @@ -189,7 +193,7 @@ def __init__(self, image_dir, label_dir=None, num_labels=17, mode='Train', input
#print(str_target['tags'].values[0].split(' '))
target_index = [label_to_idx[l] for l in str_target['tags'].values[0].split(' ')]
target[target_index] = 1
- print("image_file:", image_file)
+ #print("image_file:", image_file)
assert(os.path.isfile(image_file))
image_obj = load_img(image_file)
#print("image_obj:",image_obj)
@@ -229,14 +233,14 @@ def __getitem__(self, index):
return image, im_id
else:
image = self.images[index]
- print("retrieve image:", image)
- print("retrieve image size:", image.size)
+ #print("retrieve image:", image)
+ #print("retrieve image size:", image.size)
target = self.targets[index]
- print("retrieve target:", target)
- print("current input transform function:", self.input_transform)
+ #print("retrieve target:", target)
+ #print("current input transform function:", self.input_transform)
if self.input_transform is not None:
image = self.input_transform(image)
- print("image after transform:", image)
+ #print("image after transform:", image)
return image, torch.from_numpy(target)

def __len__(self):
75 changes: 69 additions & 6 deletions test_script.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {
"collapsed": true
},
@@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {
"collapsed": true
},
@@ -25,16 +25,18 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tranform = Compose([Scale(256),ToTensor()])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -44,7 +46,7 @@
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-b5702241243e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtranform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-4-b5702241243e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtranform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/envs/gpu_env/lib/python3.6/site-packages/torchvision-0.1.8-py3.6.egg/torchvision/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/gpu_env/lib/python3.6/site-packages/torchvision-0.1.8-py3.6.egg/torchvision/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m \u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mw\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: 'int' object is not iterable"
@@ -55,6 +57,67 @@
"tranform(img)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "img = Image.open(\"/home/ubuntu/Kaggle/AmazonForest/data/train-jpg/train_40471.jpg\")\n",
+ "\n",
+ "img = img.convert('RGB')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\n",
+ "( 0 ,.,.) = \n",
+ " 0.0980 0.1098 0.1490 ... 0.1843 0.1882 0.1686\n",
+ " 0.1020 0.1294 0.1725 ... 0.1725 0.1765 0.1608\n",
+ " 0.1098 0.1490 0.1961 ... 0.1569 0.1608 0.1490\n",
+ " ... ⋱ ... \n",
+ " 0.1255 0.1255 0.1333 ... 0.1490 0.1490 0.1686\n",
+ " 0.0980 0.1059 0.1216 ... 0.1647 0.1608 0.1725\n",
+ " 0.1020 0.1098 0.1255 ... 0.1725 0.1686 0.1765\n",
+ "\n",
+ "( 1 ,.,.) = \n",
+ " 0.1451 0.1686 0.2078 ... 0.2353 0.2392 0.2235\n",
+ " 0.1294 0.1686 0.2157 ... 0.2157 0.2196 0.2078\n",
+ " 0.1373 0.1843 0.2353 ... 0.1922 0.2000 0.1922\n",
+ " ... ⋱ ... \n",
+ " 0.1843 0.1647 0.1686 ... 0.2235 0.2275 0.2431\n",
+ " 0.1569 0.1451 0.1608 ... 0.2392 0.2392 0.2510\n",
+ " 0.1451 0.1373 0.1569 ... 0.2471 0.2471 0.2549\n",
+ "\n",
+ "( 2 ,.,.) = \n",
+ " 0.1294 0.1255 0.1529 ... 0.1961 0.2078 0.2039\n",
+ " 0.1373 0.1490 0.1804 ... 0.1765 0.1804 0.1765\n",
+ " 0.1490 0.1725 0.2039 ... 0.1569 0.1529 0.1529\n",
+ " ... ⋱ ... \n",
+ " 0.1647 0.1608 0.1608 ... 0.1765 0.1765 0.1961\n",
+ " 0.1451 0.1451 0.1529 ... 0.1882 0.1922 0.2078\n",
+ " 0.1333 0.1373 0.1529 ... 0.1961 0.2000 0.2157\n",
+ "[torch.FloatTensor of size 3x256x256]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tranform(img)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
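
What the notebook changes show: the earlier tranform(img) call raised TypeError: 'int' object is not iterable because img had been loaded through cv2, and Scale cannot unpack an integer .size into w, h. Once the image is reopened with PIL and convert('RGB'), the same Compose([Scale(256), ToTensor()]) pipeline succeeds and returns a 3x256x256 FloatTensor. A minimal reproduction of both behaviors using a synthetic image, so no file from the dataset is needed:

# Sketch, not notebook code: why Scale rejects a cv2-style array but accepts a PIL image.
import numpy as np
from PIL import Image

arr = np.zeros((256, 256, 3), dtype=np.uint8)  # stand-in for a cv2-loaded image
print(arr.size)                                # 196608, a single int
# w, h = arr.size                              # would raise TypeError: 'int' object is not iterable

img = Image.fromarray(arr).convert('RGB')      # stand-in for Image.open(...)
w, h = img.size                                # (256, 256) unpacks cleanly, as Scale expects
print(w, h)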
8 changes: 6 additions & 2 deletions trainers/train_densenet.py
@@ -44,13 +44,17 @@ def get_optimizer(model, pretrained=True, lr=5e-5, weight_decay=5e-5):

def train(epoch):
criterion = MultiLabelSoftMarginLoss()
- net = densenet169(pretrained=False)
+ #net = densenet169(pretrained=False)
+ # edited by Jing Lu
+ net = densenet169(pretrained=True)
logger = Logger('../log/', NAME)
optimizer = optim.SGD(lr=1e-1, params=net.parameters(), weight_decay=5e-4, momentum=0.9, nesterov=True)
# optimizer = get_optimizer(net, False, 1e-4, 1e-4)
# optimizer = optim.Adam(params=net.parameters(), lr=5e-4, weight_decay=5e-4)
net.cuda()
- net = torch.nn.DataParallel(net, device_ids=[0, 1])
+ # edited by Jing Lu
+ #net = torch.nn.DataParallel(net, device_ids=[0, 1])
+ net = torch.nn.DataParallel(net)
# resnet.load_state_dict(torch.load('../models/simplenet_v3.pth'))
train_data_set = train_jpg_loader(72, transform=Compose(
[
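
Two changes land in the trainer. First, densenet169(pretrained=True) starts from ImageNet weights instead of random initialization, which typically converges much faster on a small multi-label dataset like this 17-label one. Second, dropping device_ids=[0, 1] lets DataParallel default to every visible GPU rather than hard-coding two, so the script also runs on a single-GPU machine. A hedged sketch of the resulting setup; the 17-label classifier head is an assumption for illustration, since the head replacement is not visible in this hunk:

# Sketch, not the repository's exact code.
import torch
import torch.nn as nn
from torchvision.models import densenet169

net = densenet169(pretrained=True)                          # ImageNet weights
net.classifier = nn.Linear(net.classifier.in_features, 17)  # assumed 17-label head
net.cuda()
net = torch.nn.DataParallel(net)                            # defaults to all visible GPUs

criterion = nn.MultiLabelSoftMarginLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1, weight_decay=5e-4,
                            momentum=0.9, nesterov=True)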
