Skip to content

Commit

Permalink
修正:一个严重的错误
Browse files Browse the repository at this point in the history
  • Loading branch information
zergtant committed Feb 26, 2019
1 parent 6fff3a7 commit 8b84556
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions chapter4/4.2.2-tensorboardx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.405838\n",
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.206041\n",
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.144166\n"
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.271775\n",
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.175213\n",
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.115128\n",
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n"
]
}
],
Expand Down Expand Up @@ -382,7 +385,7 @@
}
],
"source": [
"vgg16 = models.vgg16() # 这里下载预训练好的模型\n",
"vgg16 = models.vgg16(pretrained=True) # 这里下载预训练好的模型\n",
"print(vgg16) # 打印一下这个模型"
]
},
Expand All @@ -401,14 +404,10 @@
"source": [
"transform_2 = transforms.Compose([\n",
" transforms.Resize(224), \n",
" transforms.CenterCrop(224),\n",
" transforms.CenterCrop((224,224)),\n",
" transforms.ToTensor(),\n",
" # convert RGB to BGR\n",
" # from <https://github.com/mrzhu-cool/pix2pix-pytorch/blob/master/util.py>\n",
" transforms.Lambda(lambda x: torch.index_select(x, 0, torch.LongTensor([2, 1, 0]))),\n",
" transforms.Lambda(lambda x: x*255),\n",
" transforms.Normalize(mean = [103.939, 116.779, 123.68],\n",
" std = [ 1, 1, 1 ]),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
"])"
]
},
Expand Down Expand Up @@ -453,17 +452,21 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 1000) 931\n"
]
"data": {
"text/plain": [
"287"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"raw_score = vgg16(vgg16_input)\n",
"raw_score_numpy = raw_score.data.numpy()\n",
"print(raw_score_numpy.shape, np.argmax(raw_score_numpy.ravel()))"
"out = vgg16(vgg16_input)\n",
"_, preds = torch.max(out.data, 1)\n",
"label=preds.numpy()[0]\n",
"label"
]
},
{
Expand All @@ -487,7 +490,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"打开tensorboard找到graphs 看看效果吧"
"打开tensorboard找到graphs就可以看到vgg模型具体的架构了"
]
},
{
Expand Down

0 comments on commit 8b84556

Please sign in to comment.