Skip to content

Commit

Permalink
added resnext
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed May 11, 2017
1 parent 4ee24b9 commit 8d3ff25
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 22 deletions.
20 changes: 19 additions & 1 deletion planet_models/resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,26 @@

class Bottleneck(Module):
"""Type C in the paper"""
def __init__(self, width, planes, cardinality, downsample=None, activation_fn=ELU):
def __init__(self, inplanes, planes, cardinality, stride, downsample=None, activation_fn=ELU):
"""
Parameters:
inplanes: # of input channels
planes: # of output channels
cardinality: # of convolution groups
stride: convolution stride
downsample: convolution operation to increase the width of the output
activation_fn: activation function
"""
super(Bottleneck, self).__init__()
depth = planes/2
self.conv1 = Conv2d(inplanes, depth, kernel_size=1, stride=1, padding=0, bias=False)
self.bn1 = BatchNorm2d(planes/2)
# group convolution
self.conv2 = Conv2d(depth, depth, kernel_size=3, groups=cardinality, stride=stride, padding=1, bias=False)
self.bn2 = BatchNorm2d(planes/2)
# increase depth
self.conv3 = Conv2d(planes/2, planes, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = BatchNorm2d(planes)

def forward(self, x):
pass
Expand Down
6 changes: 3 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from planet_models.resnet_planet import *
from trainers.train_simplenet import evaluate

MODEL='models/resnet-101.pth'
MODEL='models/resnet-34.pth'


def test(model_dir, transform):
Expand All @@ -25,7 +25,7 @@ def test(model_dir, transform):
))

if 'resnet' in model_dir:
model = nn.DataParallel(resnet101_planet())
model = nn.DataParallel(resnet34_planet())
else:
model = MultiLabelCNN(17)
model.load_state_dict(torch.load(model_dir))
Expand All @@ -39,7 +39,7 @@ def test(model_dir, transform):
result = F.sigmoid(result)
result = result.data.cpu().numpy()
for r, id in zip(result, im_ids):
r = np.where(r >= 0.15)[0]
r = np.where(r >= 0.24)[0]
labels = [idx_to_label[index] for index in r]
imid_to_label[id] = sorted(labels)
print('Batch Index {}'.format(batch_idx))
Expand Down
5 changes: 5 additions & 0 deletions trainers/threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from util import optimize_threshold
from planet_models.resnet_planet import resnet34_planet

model = resnet34_planet()
optimize_threshold(model, '../models/resnet-34.pth')
28 changes: 14 additions & 14 deletions trainers/resnet_forest.py → trainers/train_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ def train_resnet_forest(epoch=50):
best_loss = np.inf
patience = 0
for i in range(epoch):
# training
for batch_index, (target_x, target_y) in enumerate(train_data_set):
if is_cuda_availible:
target_x, target_y = target_x.cuda(), target_y.cuda()
resnet.train()
target_x, target_y = Variable(target_x), Variable(target_y)
optimizer.zero_grad()
output = resnet(target_x)
loss = criterion(output, target_y)
loss.backward()
optimizer.step()

print('Finished epoch {}'.format(i))

# evaluating
val_loss = 0.0
f2_scores = 0.0
Expand Down Expand Up @@ -62,20 +76,6 @@ def train_resnet_forest(epoch=50):
logger.add_record('train_loss', loss.data[0])
logger.add_record('evaluation_loss', val_loss.data[0]/batch_index)
logger.add_record('f2_score', f2_scores/batch_index)

# training
for batch_index, (target_x, target_y) in enumerate(train_data_set):
if is_cuda_availible:
target_x, target_y = target_x.cuda(), target_y.cuda()
resnet.train()
target_x, target_y = Variable(target_x), Variable(target_y)
optimizer.zero_grad()
output = resnet(target_x)
loss = criterion(output, target_y)
loss.backward()
optimizer.step()

print('Finished epoch {}'.format(i))
logger.save()
logger.save_plot()

Expand Down
11 changes: 7 additions & 4 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ def optimize_threshold(model, mode_dir, resolution=10000):
"""
model = nn.DataParallel(model)
model.load_state_dict(torch.load(mode_dir))
model.cuda(0)
data = validation_jpg_loader(256, transform=input_transform(227))
model.cuda()
data = validation_jpg_loader(256, transform=Compose([
Scale(224),
ToTensor()
]))
num_class = 17
pred = []
targets = []
Expand Down Expand Up @@ -130,9 +133,9 @@ def save_plot(self):
plt.plot(np.arange(len(eval_loss)), eval_loss, color='blue', label='eval_loss')
plt.legend(loc='best')

plt.savefig('log/%s_losses.jpg' % self.name)
plt.savefig('../log/%s_losses.jpg' % self.name)

plt.figure()
plt.plot(np.arange(len(f2_scores)), f2_scores)
plt.savefig('log/%s_fcscore.jpg' % self.name)
plt.savefig('../log/%s_fcscore.jpg' % self.name)

0 comments on commit 8d3ff25

Please sign in to comment.