|
32 | 32 | "cell_type": "code",
|
33 | 33 | "execution_count": 2,
|
34 | 34 | "metadata": {},
|
35 |
| - "outputs": [], |
| 35 | + "outputs": [ |
| 36 | + { |
| 37 | + "name": "stdout", |
| 38 | + "output_type": "stream", |
| 39 | + "text": [ |
| 40 | + "Obtaining nupic.torch from git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch\n", |
| 41 | + " Updating ./src/nupic.torch clone\n", |
| 42 | + " Installing build dependencies ... \u001b[?25ldone\n", |
| 43 | + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", |
| 44 | + "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n", |
| 45 | + "\u001b[?25hRequirement already satisfied: torchvision>=0.2.2 in /Users/lscheinkman/miniconda3/envs/nupic.torch/lib/python3.7/site-packages/torchvision-0.2.2.post3-py3.7.egg (from nupic.torch) (0.2.2.post3)\n", |
| 46 | + "Requirement already satisfied: numpy>=1.16.2 in /Users/lscheinkman/miniconda3/envs/nupic.torch/lib/python3.7/site-packages/numpy-1.16.2-py3.7-macosx-10.7-x86_64.egg (from nupic.torch) (1.16.2)\n", |
| 47 | + "Requirement already satisfied: torch>=1.0.1 in /Users/lscheinkman/miniconda3/envs/nupic.torch/lib/python3.7/site-packages/torch-1.0.1.post2-py3.7-macosx-10.7-x86_64.egg (from nupic.torch) (1.0.1.post2)\n", |
| 48 | + "Requirement already satisfied: pillow>=4.1.1 in /Users/lscheinkman/miniconda3/envs/nupic.torch/lib/python3.7/site-packages/Pillow-6.0.0-py3.7-macosx-10.7-x86_64.egg (from torchvision>=0.2.2->nupic.torch) (6.0.0)\n", |
| 49 | + "Requirement already satisfied: six in /Users/lscheinkman/miniconda3/envs/nupic.torch/lib/python3.7/site-packages/six-1.12.0-py3.7.egg (from torchvision>=0.2.2->nupic.torch) (1.12.0)\n", |
| 50 | + "Installing collected packages: nupic.torch\n", |
| 51 | + " Found existing installation: nupic.torch 0.0.1.dev0\n", |
| 52 | + " Uninstalling nupic.torch-0.0.1.dev0:\n", |
| 53 | + " Successfully uninstalled nupic.torch-0.0.1.dev0\n", |
| 54 | + " Running setup.py develop for nupic.torch\n", |
| 55 | + "Successfully installed nupic.torch\n" |
| 56 | + ] |
| 57 | + } |
| 58 | + ], |
36 | 59 | "source": [
|
37 | 60 | "# Install nupic.torch (comment out the following line if it is already installed)\n",
|
38 |
| - "#!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" |
| 61 | + "!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" |
39 | 62 | ]
|
40 | 63 | },
|
41 | 64 | {
|
|
196 | 219 | "metadata": {},
|
197 | 220 | "outputs": [],
|
198 | 221 | "source": [
|
199 |
| - "from nupic.torch.modules import KWinners2d, KWinners, SparseWeights, Flatten\n", |
| 222 | + "from nupic.torch.modules import (\n", |
| 223 | + " KWinners2d, KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n", |
200 | 224 | "\n",
|
201 | 225 | "sparseCNN = nn.Sequential(\n",
|
202 | 226 | " # Sparse CNN layer\n",
|
|
255 | 279 | "cell_type": "code",
|
256 | 280 | "execution_count": 8,
|
257 | 281 | "metadata": {},
|
| 282 | + "outputs": [], |
| 283 | + "source": [ |
| 284 | + "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", |
| 285 | + "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)" |
| 286 | + ] |
| 287 | + }, |
| 288 | + { |
| 289 | + "cell_type": "markdown", |
| 290 | + "metadata": {}, |
| 291 | + "source": [ |
| 292 | + "After each epoch, we rezero the weights to keep the initial sparsity constant during training. We also apply the boost strength factor after each epoch." |
| 293 | + ] |
| 294 | + }, |
| 295 | + { |
| 296 | + "cell_type": "code", |
| 297 | + "execution_count": 9, |
| 298 | + "metadata": {}, |
| 299 | + "outputs": [], |
| 300 | + "source": [ |
| 301 | + "%%capture\n", |
| 302 | + "sparseCNN.apply(rezeroWeights)\n", |
| 303 | + "sparseCNN.apply(updateBoostStrength)" |
| 304 | + ] |
| 305 | + }, |
| 306 | + { |
| 307 | + "cell_type": "markdown", |
| 308 | + "metadata": {}, |
| 309 | + "source": [ |
| 310 | + "Test and print results" |
| 311 | + ] |
| 312 | + }, |
| 313 | + { |
| 314 | + "cell_type": "code", |
| 315 | + "execution_count": 10, |
| 316 | + "metadata": {}, |
258 | 317 | "outputs": [
|
259 | 318 | {
|
260 |
| - "name": "stdout", |
261 |
| - "output_type": "stream", |
262 |
| - "text": [ |
263 |
| - "{'accuracy': 0.978, 'loss': 0.06757185096740723, 'total_correct': 9780}\n" |
264 |
| - ] |
| 319 | + "data": { |
| 320 | + "text/plain": [ |
| 321 | + "{'accuracy': 0.9782, 'loss': 0.06787856979370117, 'total_correct': 9782}" |
| 322 | + ] |
| 323 | + }, |
| 324 | + "execution_count": 10, |
| 325 | + "metadata": {}, |
| 326 | + "output_type": "execute_result" |
265 | 327 | }
|
266 | 328 | ],
|
267 | 329 | "source": [
|
268 |
| - "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", |
269 |
| - "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n", |
270 |
| - "results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n", |
271 |
| - "print(results)" |
| 330 | + "test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)" |
272 | 331 | ]
|
273 | 332 | },
|
274 | 333 | {
|
|
280 | 339 | },
|
281 | 340 | {
|
282 | 341 | "cell_type": "code",
|
283 |
| - "execution_count": 9, |
| 342 | + "execution_count": 11, |
284 | 343 | "metadata": {},
|
285 | 344 | "outputs": [
|
286 | 345 | {
|
287 | 346 | "name": "stdout",
|
288 | 347 | "output_type": "stream",
|
289 | 348 | "text": [
|
290 |
| - "{'accuracy': 0.9856, 'loss': 0.04167602500915527, 'total_correct': 9856}\n", |
291 |
| - "{'accuracy': 0.9868, 'loss': 0.040336785697937014, 'total_correct': 9868}\n", |
292 |
| - "{'accuracy': 0.9872, 'loss': 0.03909029731750488, 'total_correct': 9872}\n", |
293 |
| - "{'accuracy': 0.9874, 'loss': 0.037309212684631346, 'total_correct': 9874}\n", |
294 |
| - "{'accuracy': 0.9876, 'loss': 0.037822017288208006, 'total_correct': 9876}\n", |
295 |
| - "{'accuracy': 0.9877, 'loss': 0.03768303909301758, 'total_correct': 9877}\n", |
296 |
| - "{'accuracy': 0.988, 'loss': 0.03783873291015625, 'total_correct': 9880}\n", |
297 |
| - "{'accuracy': 0.9881, 'loss': 0.038252718925476076, 'total_correct': 9881}\n", |
298 |
| - "{'accuracy': 0.9873, 'loss': 0.03846522216796875, 'total_correct': 9873}\n" |
| 349 | + "{'accuracy': 0.9862, 'loss': 0.0412615779876709, 'total_correct': 9862}\n", |
| 350 | + "{'accuracy': 0.9868, 'loss': 0.04029187545776367, 'total_correct': 9868}\n", |
| 351 | + "{'accuracy': 0.9867, 'loss': 0.03934368209838867, 'total_correct': 9867}\n", |
| 352 | + "{'accuracy': 0.9876, 'loss': 0.03759277114868164, 'total_correct': 9876}\n", |
| 353 | + "{'accuracy': 0.9877, 'loss': 0.03777754402160644, 'total_correct': 9877}\n", |
| 354 | + "{'accuracy': 0.9872, 'loss': 0.03784116630554199, 'total_correct': 9872}\n", |
| 355 | + "{'accuracy': 0.9872, 'loss': 0.03829168930053711, 'total_correct': 9872}\n", |
| 356 | + "{'accuracy': 0.9876, 'loss': 0.03837165260314941, 'total_correct': 9876}\n", |
| 357 | + "{'accuracy': 0.9871, 'loss': 0.03919747161865234, 'total_correct': 9871}\n" |
299 | 358 | ]
|
300 | 359 | }
|
301 | 360 | ],
|
302 | 361 | "source": [
|
303 | 362 | "for epoch in range(1, EPOCHS):\n",
|
304 | 363 | " train(model=sparseCNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n",
|
| 364 | + " sparseCNN.apply(rezeroWeights)\n", |
| 365 | + " sparseCNN.apply(updateBoostStrength)\n", |
305 | 366 | " results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n",
|
306 | 367 | " print(results)"
|
307 | 368 | ]
|
|
316 | 377 | },
|
317 | 378 | {
|
318 | 379 | "cell_type": "code",
|
319 |
| - "execution_count": 10, |
| 380 | + "execution_count": 12, |
320 | 381 | "metadata": {},
|
321 | 382 | "outputs": [
|
322 | 383 | {
|
323 | 384 | "name": "stdout",
|
324 | 385 | "output_type": "stream",
|
325 | 386 | "text": [
|
326 |
| - "0.05 : {'accuracy': 0.9842, 'loss': 0.04920544128417969, 'total_correct': 9842}\n", |
327 |
| - "0.1 : {'accuracy': 0.9782, 'loss': 0.067373148727417, 'total_correct': 9782}\n", |
328 |
| - "0.15 : {'accuracy': 0.9697, 'loss': 0.0976463264465332, 'total_correct': 9697}\n", |
329 |
| - "0.2 : {'accuracy': 0.9524, 'loss': 0.15406905670166016, 'total_correct': 9524}\n", |
330 |
| - "0.25 : {'accuracy': 0.9238, 'loss': 0.23606297912597657, 'total_correct': 9238}\n" |
| 387 | + "0.05 : {'accuracy': 0.9841, 'loss': 0.05029125137329102, 'total_correct': 9841}\n", |
| 388 | + "0.1 : {'accuracy': 0.9777, 'loss': 0.06853677139282227, 'total_correct': 9777}\n", |
| 389 | + "0.15 : {'accuracy': 0.9697, 'loss': 0.09935387115478515, 'total_correct': 9697}\n", |
| 390 | + "0.2 : {'accuracy': 0.9509, 'loss': 0.1598887924194336, 'total_correct': 9509}\n", |
| 391 | + "0.25 : {'accuracy': 0.9225, 'loss': 0.2412475631713867, 'total_correct': 9225}\n" |
331 | 392 | ]
|
332 | 393 | }
|
333 | 394 | ],
|
|
359 | 420 | "name": "python",
|
360 | 421 | "nbconvert_exporter": "python",
|
361 | 422 | "pygments_lexer": "ipython3",
|
362 |
| - "version": "3.6.7" |
| 423 | + "version": "3.7.3" |
363 | 424 | }
|
364 | 425 | },
|
365 | 426 | "nbformat": 4,
|
|
0 commit comments