Skip to content

Commit

Permalink
add resnext
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed May 10, 2017
1 parent 6440bed commit 4ee24b9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 31 deletions.
9 changes: 0 additions & 9 deletions planet_models/resnetxt.py

This file was deleted.

39 changes: 39 additions & 0 deletions planet_models/resnext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
A re-implementation of ResNeXT. All the blocks are of bottleneck type.
The code follows the style of resnet.py in pytorch vision model.
"""
import torch
from torch.nn import *
from torch.nn import functional as F


class Bottleneck(Module):
"""Type C in the paper"""
def __init__(self, width, planes, cardinality, downsample=None, activation_fn=ELU):
super(Bottleneck, self).__init__()

def forward(self, x):
pass


class ResNeXT(Module):
def __init__(self, block, depths, num_classes, cardinality=32, activation_fn=ELU):
super(ResNeXT, self).__init__()
self.inplanes = 64
self.cardinality = cardinality
self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = BatchNorm2d(64)
self.activation = activation_fn(inplace=True)
self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
self.stage1 = self._make_layers(block, self.inplanes, depths[0])
self.stage2 = self._make_layers(block, self.inplanes, depths[1], stride=2)
self.stage3 = self._make_layers(block, self.inplanes, depths[2], stride=2)
self.stage4 = self._make_layers(block, self.inplanes, depths[3], stride=2)
self.avgpool = AvgPool2d(7)
self.fc = Linear(block.expansion*512, num_classes)

def _make_layers(self, block, planes, blocks, stride=1):
pass

def forward(self, x):
pass
56 changes: 34 additions & 22 deletions trainers/resnet_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,41 @@ def train_resnet_forest(epoch=50):
ToTensor()
]
))

best_loss = np.inf
patience = 0
for i in range(epoch):
# save the model for every epoch
torch.save(resnet.state_dict(), '../models/resnet-34.pth')
# evaluating
val_loss = 0.0
f2_scores = 0.0
resnet.eval()
for batch_index, (val_x, val_y) in enumerate(validation_data_set):
if is_cuda_availible:
val_y = val_y.cuda()
val_y = Variable(val_y, volatile=True)
val_output = evaluate(resnet, val_x)
val_loss += criterion(val_output, val_y)
binary_y = threshold_labels(val_output.data.cpu().numpy())
f2 = f2_score(val_y.data.cpu().numpy(), binary_y)
f2_scores += f2
if best_loss > val_loss:
best_loss = val_loss
torch.save(resnet.state_dict(), '../models/resnet-34.pth')
else:
print('Reload previous model')
patience += 1
resnet.load_state_dict(torch.load('../models/resnet-34.pth'))

if patience >= 5:
print('Early stopping!')
break

print('Evaluation loss is {}, Training loss is {}'.format(val_loss.data[0]/batch_index, loss.data[0]))
print('F2 Score is %s' % (f2_scores/batch_index))
logger.add_record('train_loss', loss.data[0])
logger.add_record('evaluation_loss', val_loss.data[0]/batch_index)
logger.add_record('f2_score', f2_scores/batch_index)

# training
for batch_index, (target_x, target_y) in enumerate(train_data_set):
if is_cuda_availible:
target_x, target_y = target_x.cuda(), target_y.cuda()
Expand All @@ -44,25 +75,6 @@ def train_resnet_forest(epoch=50):
loss.backward()
optimizer.step()

if batch_index % 50 == 0:
val_loss = 0.0
f2_scores = 0.0
resnet.eval()
for batch_index, (val_x, val_y) in enumerate(validation_data_set):
if is_cuda_availible:
val_y = val_y.cuda()
val_y = Variable(val_y, volatile=True)
val_output = evaluate(resnet, val_x)
val_loss += criterion(val_output, val_y)
binary_y = threshold_labels(val_output.data.cpu().numpy())
f2 = f2_score(val_y.data.cpu().numpy(), binary_y)
f2_scores += f2
print('Evaluation loss is {}, Training loss is {}'.format(val_loss.data[0]/batch_index,
loss.data[0]))
print('F2 Score is %s' % (f2_scores/batch_index))
logger.add_record('train_loss', loss.data[0])
logger.add_record('evaluation_loss', val_loss.data[0]/batch_index)
logger.add_record('f2_score', f2_scores/batch_index)
print('Finished epoch {}'.format(i))
logger.save()
logger.save_plot()
Expand Down

0 comments on commit 4ee24b9

Please sign in to comment.