|
30 | 30 | },
|
31 | 31 | {
|
32 | 32 | "cell_type": "code",
|
33 |
| - "execution_count": 2, |
| 33 | + "execution_count": null, |
34 | 34 | "metadata": {},
|
35 | 35 | "outputs": [],
|
36 | 36 | "source": [
|
37 | 37 | "# Uncomment the following line to install nupic.torch\n",
|
38 |
| - "#!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" |
| 38 | + "!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" |
39 | 39 | ]
|
40 | 40 | },
|
41 | 41 | {
|
|
196 | 196 | "metadata": {},
|
197 | 197 | "outputs": [],
|
198 | 198 | "source": [
|
199 |
| - "from nupic.torch.modules import KWinners2d, KWinners, SparseWeights, Flatten\n", |
| 199 | + "from nupic.torch.modules import (\n", |
| 200 | + " KWinners2d, KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n", |
200 | 201 | "\n",
|
201 | 202 | "sparseCNN = nn.Sequential(\n",
|
202 | 203 | " # Sparse CNN layer\n",
|
|
255 | 256 | "cell_type": "code",
|
256 | 257 | "execution_count": 8,
|
257 | 258 | "metadata": {},
|
| 259 | + "outputs": [], |
| 260 | + "source": [ |
| 261 | + "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", |
| 262 | + "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)" |
| 263 | + ] |
| 264 | + }, |
| 265 | + { |
| 266 | + "cell_type": "markdown", |
| 267 | + "metadata": {}, |
| 268 | + "source": [ |
| 269 | + "After each epoch we rezero the weights to keep the initial sparsity constant during training. We also apply the boost strength factor after each epoch" |
| 270 | + ] |
| 271 | + }, |
| 272 | + { |
| 273 | + "cell_type": "code", |
| 274 | + "execution_count": 9, |
| 275 | + "metadata": {}, |
| 276 | + "outputs": [], |
| 277 | + "source": [ |
| 278 | + "%%capture\n", |
| 279 | + "sparseCNN.apply(rezeroWeights)\n", |
| 280 | + "sparseCNN.apply(updateBoostStrength)" |
| 281 | + ] |
| 282 | + }, |
| 283 | + { |
| 284 | + "cell_type": "markdown", |
| 285 | + "metadata": {}, |
| 286 | + "source": [ |
| 287 | + "Test and print results" |
| 288 | + ] |
| 289 | + }, |
| 290 | + { |
| 291 | + "cell_type": "code", |
| 292 | + "execution_count": 10, |
| 293 | + "metadata": {}, |
258 | 294 | "outputs": [
|
259 | 295 | {
|
260 |
| - "name": "stdout", |
261 |
| - "output_type": "stream", |
262 |
| - "text": [ |
263 |
| - "{'accuracy': 0.978, 'loss': 0.06757185096740723, 'total_correct': 9780}\n" |
264 |
| - ] |
| 296 | + "data": { |
| 297 | + "text/plain": [ |
| 298 | + "{'accuracy': 0.9782, 'loss': 0.06787856979370117, 'total_correct': 9782}" |
| 299 | + ] |
| 300 | + }, |
| 301 | + "execution_count": 10, |
| 302 | + "metadata": {}, |
| 303 | + "output_type": "execute_result" |
265 | 304 | }
|
266 | 305 | ],
|
267 | 306 | "source": [
|
268 |
| - "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", |
269 |
| - "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n", |
270 |
| - "results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n", |
271 |
| - "print(results)" |
| 307 | + "test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)" |
272 | 308 | ]
|
273 | 309 | },
|
274 | 310 | {
|
|
280 | 316 | },
|
281 | 317 | {
|
282 | 318 | "cell_type": "code",
|
283 |
| - "execution_count": 9, |
| 319 | + "execution_count": 11, |
284 | 320 | "metadata": {},
|
285 | 321 | "outputs": [
|
286 | 322 | {
|
287 | 323 | "name": "stdout",
|
288 | 324 | "output_type": "stream",
|
289 | 325 | "text": [
|
290 |
| - "{'accuracy': 0.9856, 'loss': 0.04167602500915527, 'total_correct': 9856}\n", |
291 |
| - "{'accuracy': 0.9868, 'loss': 0.040336785697937014, 'total_correct': 9868}\n", |
292 |
| - "{'accuracy': 0.9872, 'loss': 0.03909029731750488, 'total_correct': 9872}\n", |
293 |
| - "{'accuracy': 0.9874, 'loss': 0.037309212684631346, 'total_correct': 9874}\n", |
294 |
| - "{'accuracy': 0.9876, 'loss': 0.037822017288208006, 'total_correct': 9876}\n", |
295 |
| - "{'accuracy': 0.9877, 'loss': 0.03768303909301758, 'total_correct': 9877}\n", |
296 |
| - "{'accuracy': 0.988, 'loss': 0.03783873291015625, 'total_correct': 9880}\n", |
297 |
| - "{'accuracy': 0.9881, 'loss': 0.038252718925476076, 'total_correct': 9881}\n", |
298 |
| - "{'accuracy': 0.9873, 'loss': 0.03846522216796875, 'total_correct': 9873}\n" |
| 326 | + "{'accuracy': 0.9862, 'loss': 0.0412615779876709, 'total_correct': 9862}\n", |
| 327 | + "{'accuracy': 0.9868, 'loss': 0.04029187545776367, 'total_correct': 9868}\n", |
| 328 | + "{'accuracy': 0.9867, 'loss': 0.03934368209838867, 'total_correct': 9867}\n", |
| 329 | + "{'accuracy': 0.9876, 'loss': 0.03759277114868164, 'total_correct': 9876}\n", |
| 330 | + "{'accuracy': 0.9877, 'loss': 0.03777754402160644, 'total_correct': 9877}\n", |
| 331 | + "{'accuracy': 0.9872, 'loss': 0.03784116630554199, 'total_correct': 9872}\n", |
| 332 | + "{'accuracy': 0.9872, 'loss': 0.03829168930053711, 'total_correct': 9872}\n", |
| 333 | + "{'accuracy': 0.9876, 'loss': 0.03837165260314941, 'total_correct': 9876}\n", |
| 334 | + "{'accuracy': 0.9871, 'loss': 0.03919747161865234, 'total_correct': 9871}\n" |
299 | 335 | ]
|
300 | 336 | }
|
301 | 337 | ],
|
302 | 338 | "source": [
|
303 | 339 | "for epoch in range(1, EPOCHS):\n",
|
304 | 340 | " train(model=sparseCNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n",
|
| 341 | + " sparseCNN.apply(rezeroWeights)\n", |
| 342 | + " sparseCNN.apply(updateBoostStrength)\n", |
305 | 343 | " results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n",
|
306 | 344 | " print(results)"
|
307 | 345 | ]
|
|
316 | 354 | },
|
317 | 355 | {
|
318 | 356 | "cell_type": "code",
|
319 |
| - "execution_count": 10, |
| 357 | + "execution_count": 12, |
320 | 358 | "metadata": {},
|
321 | 359 | "outputs": [
|
322 | 360 | {
|
323 | 361 | "name": "stdout",
|
324 | 362 | "output_type": "stream",
|
325 | 363 | "text": [
|
326 |
| - "0.05 : {'accuracy': 0.9842, 'loss': 0.04920544128417969, 'total_correct': 9842}\n", |
327 |
| - "0.1 : {'accuracy': 0.9782, 'loss': 0.067373148727417, 'total_correct': 9782}\n", |
328 |
| - "0.15 : {'accuracy': 0.9697, 'loss': 0.0976463264465332, 'total_correct': 9697}\n", |
329 |
| - "0.2 : {'accuracy': 0.9524, 'loss': 0.15406905670166016, 'total_correct': 9524}\n", |
330 |
| - "0.25 : {'accuracy': 0.9238, 'loss': 0.23606297912597657, 'total_correct': 9238}\n" |
| 364 | + "0.05 : {'accuracy': 0.9841, 'loss': 0.05029125137329102, 'total_correct': 9841}\n", |
| 365 | + "0.1 : {'accuracy': 0.9777, 'loss': 0.06853677139282227, 'total_correct': 9777}\n", |
| 366 | + "0.15 : {'accuracy': 0.9697, 'loss': 0.09935387115478515, 'total_correct': 9697}\n", |
| 367 | + "0.2 : {'accuracy': 0.9509, 'loss': 0.1598887924194336, 'total_correct': 9509}\n", |
| 368 | + "0.25 : {'accuracy': 0.9225, 'loss': 0.2412475631713867, 'total_correct': 9225}\n" |
331 | 369 | ]
|
332 | 370 | }
|
333 | 371 | ],
|
|
359 | 397 | "name": "python",
|
360 | 398 | "nbconvert_exporter": "python",
|
361 | 399 | "pygments_lexer": "ipython3",
|
362 |
| - "version": "3.6.7" |
| 400 | + "version": "3.7.3" |
363 | 401 | }
|
364 | 402 | },
|
365 | 403 | "nbformat": 4,
|
|
0 commit comments