Skip to content

Commit

Permalink
try to run benchmark model
Browse files Browse the repository at this point in the history
  • Loading branch information
ajing committed Jul 6, 2017
1 parent 7e67f67 commit 2239eea
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
3 changes: 2 additions & 1 deletion data/kgdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def __init__(self, split, transform=None, height=256, width=256, label_csv='trai
images = np.zeros((num,height,width,3),dtype=np.float32)
for n in range(num):
img_file = data_dir + '/{}/'.format(ext) + names[n]
print(img_file)
#print(img_file)
jpg_file = img_file.replace('<ext>','jpg')
#print(jpg_file)
image_jpg = cv2.imread(jpg_file,1)
h,w = image_jpg.shape[0:2]
if height!=h or width!=w:
Expand Down
8 changes: 6 additions & 2 deletions split_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pds

def split_train_validation(num_val=3000):
def split_train_validation2(num_val=3000):
"""
Save train image names and validation image names to csv files
"""
Expand Down Expand Up @@ -32,4 +32,8 @@ def split_train_validation(num_val=3000):
df.to_csv('dataset/validation-%s' % num_val, index=False, header=False)


split_train_validation(num_val=3000)
#split_train_validation2(num_val=3000)

from util import split_train_validation

split_train_validation(num_val = 3000)
11 changes: 7 additions & 4 deletions trainers/baseline_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! Not require X-server to be running

import torch.nn as nn
from torch.nn import functional as F
from torch import optim
Expand Down Expand Up @@ -46,8 +49,8 @@
]
batch_size = [
# 128, 128,
64, 64,
40, 40,
16, 16,
16, 16,
# 50
]

Expand Down Expand Up @@ -116,7 +119,7 @@ def load_net(net, name):

def train_baselines():

train_data, val_data = get_dataloader(96)
train_data, val_data = get_dataloader(32)

for model, batch in zip(models, batch_size):
name = str(model).split()[1]
Expand All @@ -127,7 +130,7 @@ def train_baselines():
# load pre-trained model on train-37479
net = model(pretrained=True)
net = nn.DataParallel(net.cuda())
load_net(net, name)
#load_net(net, name)
# optimizer = get_optimizer(net, lr=.001, pretrained=True, resnet=True if 'resnet' in name else False)
optimizer = optim.SGD(lr=.005, momentum=0.9, params=net.parameters(), weight_decay=5e-4)
train_data.batch_size = batch
Expand Down
2 changes: 2 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from torch.autograd import Variable
from sklearn.metrics import fbeta_score
from torch.nn import functional as F
import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! Not require X-server to be running
from matplotlib import pyplot as plt
import pandas as pds
from datasets import *
Expand Down

0 comments on commit 2239eea

Please sign in to comment.