diff --git a/planet_models/resnetxt.py b/planet_models/resnetxt.py deleted file mode 100644 index 2be6e4d..0000000 --- a/planet_models/resnetxt.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from torch.nn import Module -from torch.nn import functional as F - - -class Block(Module): - def __init__(self): - super(Block, self).__init__() - diff --git a/planet_models/resnext.py b/planet_models/resnext.py new file mode 100644 index 0000000..8b35566 --- /dev/null +++ b/planet_models/resnext.py @@ -0,0 +1,39 @@ +""" + A re-implementation of ResNeXT. All the blocks are of bottleneck type. + The code follows the style of resnet.py in pytorch vision model. +""" +import torch +from torch.nn import * +from torch.nn import functional as F + + +class Bottleneck(Module): + """Type C in the paper""" + def __init__(self, width, planes, cardinality, downsample=None, activation_fn=ELU): + super(Bottleneck, self).__init__() + + def forward(self, x): + pass + + +class ResNeXT(Module): + def __init__(self, block, depths, num_classes, cardinality=32, activation_fn=ELU): + super(ResNeXT, self).__init__() + self.inplanes = 64 + self.cardinality = cardinality + self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm2d(64) + self.activation = activation_fn(inplace=True) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + self.stage1 = self._make_layers(block, self.inplanes, depths[0]) + self.stage2 = self._make_layers(block, self.inplanes, depths[1], stride=2) + self.stage3 = self._make_layers(block, self.inplanes, depths[2], stride=2) + self.stage4 = self._make_layers(block, self.inplanes, depths[3], stride=2) + self.avgpool = AvgPool2d(7) + self.fc = Linear(block.expansion*512, num_classes) + + def _make_layers(self, block, planes, blocks, stride=1): + pass + + def forward(self, x): + pass diff --git a/trainers/resnet_forest.py b/trainers/resnet_forest.py index 07af050..5a86444 100644 --- a/trainers/resnet_forest.py +++ b/trainers/resnet_forest.py @@ -29,10 +29,41 @@ def train_resnet_forest(epoch=50): ToTensor() ] )) - + best_loss = np.inf + patience = 0 for i in range(epoch): - # save the model for every epoch - torch.save(resnet.state_dict(), '../models/resnet-34.pth') + # evaluating + val_loss = 0.0 + f2_scores = 0.0 + resnet.eval() + for batch_index, (val_x, val_y) in enumerate(validation_data_set): + if is_cuda_availible: + val_y = val_y.cuda() + val_y = Variable(val_y, volatile=True) + val_output = evaluate(resnet, val_x) + val_loss += criterion(val_output, val_y) + binary_y = threshold_labels(val_output.data.cpu().numpy()) + f2 = f2_score(val_y.data.cpu().numpy(), binary_y) + f2_scores += f2 + if best_loss > val_loss: + best_loss = val_loss + torch.save(resnet.state_dict(), '../models/resnet-34.pth') + else: + print('Reload previous model') + patience += 1 + resnet.load_state_dict(torch.load('../models/resnet-34.pth')) + + if patience >= 5: + print('Early stopping!') + break + + print('Evaluation loss is {}, Training loss is {}'.format(val_loss.data[0]/batch_index, loss.data[0])) + print('F2 Score is %s' % (f2_scores/batch_index)) + logger.add_record('train_loss', loss.data[0]) + logger.add_record('evaluation_loss', val_loss.data[0]/batch_index) + logger.add_record('f2_score', f2_scores/batch_index) + + # training for batch_index, (target_x, target_y) in enumerate(train_data_set): if is_cuda_availible: target_x, target_y = target_x.cuda(), target_y.cuda() @@ -44,25 +75,6 @@ def train_resnet_forest(epoch=50): loss.backward() optimizer.step() - if batch_index % 50 == 0: - val_loss = 0.0 - f2_scores = 0.0 - resnet.eval() - for batch_index, (val_x, val_y) in enumerate(validation_data_set): - if is_cuda_availible: - val_y = val_y.cuda() - val_y = Variable(val_y, volatile=True) - val_output = evaluate(resnet, val_x) - val_loss += criterion(val_output, val_y) - binary_y = threshold_labels(val_output.data.cpu().numpy()) - f2 = f2_score(val_y.data.cpu().numpy(), binary_y) - f2_scores += f2 - print('Evaluation loss is {}, Training loss is {}'.format(val_loss.data[0]/batch_index, - loss.data[0])) - print('F2 Score is %s' % (f2_scores/batch_index)) - logger.add_record('train_loss', loss.data[0]) - logger.add_record('evaluation_loss', val_loss.data[0]/batch_index) - logger.add_record('f2_score', f2_scores/batch_index) print('Finished epoch {}'.format(i)) logger.save() logger.save_plot()