Commit 1ca6ea6

Add updateBoostStrength and rezeroWeights to examples
1 parent 8638ba3 commit 1ca6ea6

2 files changed (+140 -60)

examples/sparse_cnn.ipynb (+67 -29)

@@ -30,12 +30,12 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "# Uncomment the following line to install nupic.torch\n",
-"#!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch"
+"!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch"
 ]
 },
 {
@@ -196,7 +196,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from nupic.torch.modules import KWinners2d, KWinners, SparseWeights, Flatten\n",
+"from nupic.torch.modules import (\n",
+"    KWinners2d, KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n",
 "\n",
 "sparseCNN = nn.Sequential(\n",
 "    # Sparse CNN layer\n",
@@ -255,20 +256,55 @@
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
+"outputs": [],
+"source": [
+"sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
+"train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"After each epoch we rezero the weights to keep the initial sparsity constant during training. We also apply the boost strength factor after each epoch"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"metadata": {},
+"outputs": [],
+"source": [
+"%%capture\n",
+"sparseCNN.apply(rezeroWeights)\n",
+"sparseCNN.apply(updateBoostStrength)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Test and print results"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {},
 "outputs": [
 {
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"{'accuracy': 0.978, 'loss': 0.06757185096740723, 'total_correct': 9780}\n"
-]
+"data": {
+"text/plain": [
+"{'accuracy': 0.9782, 'loss': 0.06787856979370117, 'total_correct': 9782}"
+]
+},
+"execution_count": 10,
+"metadata": {},
+"output_type": "execute_result"
 }
 ],
 "source": [
-"sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
-"train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n",
-"results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n",
-"print(results)"
+"test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)"
 ]
 },
 {
@@ -280,28 +316,30 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"{'accuracy': 0.9856, 'loss': 0.04167602500915527, 'total_correct': 9856}\n",
-"{'accuracy': 0.9868, 'loss': 0.040336785697937014, 'total_correct': 9868}\n",
-"{'accuracy': 0.9872, 'loss': 0.03909029731750488, 'total_correct': 9872}\n",
-"{'accuracy': 0.9874, 'loss': 0.037309212684631346, 'total_correct': 9874}\n",
-"{'accuracy': 0.9876, 'loss': 0.037822017288208006, 'total_correct': 9876}\n",
-"{'accuracy': 0.9877, 'loss': 0.03768303909301758, 'total_correct': 9877}\n",
-"{'accuracy': 0.988, 'loss': 0.03783873291015625, 'total_correct': 9880}\n",
-"{'accuracy': 0.9881, 'loss': 0.038252718925476076, 'total_correct': 9881}\n",
-"{'accuracy': 0.9873, 'loss': 0.03846522216796875, 'total_correct': 9873}\n"
+"{'accuracy': 0.9862, 'loss': 0.0412615779876709, 'total_correct': 9862}\n",
+"{'accuracy': 0.9868, 'loss': 0.04029187545776367, 'total_correct': 9868}\n",
+"{'accuracy': 0.9867, 'loss': 0.03934368209838867, 'total_correct': 9867}\n",
+"{'accuracy': 0.9876, 'loss': 0.03759277114868164, 'total_correct': 9876}\n",
+"{'accuracy': 0.9877, 'loss': 0.03777754402160644, 'total_correct': 9877}\n",
+"{'accuracy': 0.9872, 'loss': 0.03784116630554199, 'total_correct': 9872}\n",
+"{'accuracy': 0.9872, 'loss': 0.03829168930053711, 'total_correct': 9872}\n",
+"{'accuracy': 0.9876, 'loss': 0.03837165260314941, 'total_correct': 9876}\n",
+"{'accuracy': 0.9871, 'loss': 0.03919747161865234, 'total_correct': 9871}\n"
 ]
 }
 ],
 "source": [
 "for epoch in range(1, EPOCHS):\n",
 "    train(model=sparseCNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n",
+"    sparseCNN.apply(rezeroWeights)\n",
+"    sparseCNN.apply(updateBoostStrength)\n",
 "    results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n",
 "    print(results)"
 ]
@@ -316,18 +354,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 12,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"0.05 : {'accuracy': 0.9842, 'loss': 0.04920544128417969, 'total_correct': 9842}\n",
-"0.1 : {'accuracy': 0.9782, 'loss': 0.067373148727417, 'total_correct': 9782}\n",
-"0.15 : {'accuracy': 0.9697, 'loss': 0.0976463264465332, 'total_correct': 9697}\n",
-"0.2 : {'accuracy': 0.9524, 'loss': 0.15406905670166016, 'total_correct': 9524}\n",
-"0.25 : {'accuracy': 0.9238, 'loss': 0.23606297912597657, 'total_correct': 9238}\n"
+"0.05 : {'accuracy': 0.9841, 'loss': 0.05029125137329102, 'total_correct': 9841}\n",
+"0.1 : {'accuracy': 0.9777, 'loss': 0.06853677139282227, 'total_correct': 9777}\n",
+"0.15 : {'accuracy': 0.9697, 'loss': 0.09935387115478515, 'total_correct': 9697}\n",
+"0.2 : {'accuracy': 0.9509, 'loss': 0.1598887924194336, 'total_correct': 9509}\n",
+"0.25 : {'accuracy': 0.9225, 'loss': 0.2412475631713867, 'total_correct': 9225}\n"
 ]
 }
 ],
@@ -359,7 +397,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.7"
+"version": "3.7.3"
 }
 },
 "nbformat": 4,

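Taken together, the hunks above change the CNN notebook's flow: the first epoch (on the small-batch loader) gets its own cell, rezeroWeights and updateBoostStrength are pushed through the model with nn.Module.apply, the test result becomes a cell output, and every remaining epoch re-applies both helpers. A minimal sketch of the resulting loop, assuming the notebook's sparseCNN, train, test, data loaders and hyperparameters (LEARNING_RATE, MOMENTUM, EPOCHS) are already in scope, and using only the calls that appear in the diff:

import torch.nn.functional as F
import torch.optim as optim
from nupic.torch.modules import rezeroWeights, updateBoostStrength

sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# First epoch on the small-batch loader (the new execution_count-8 cell).
train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)
sparseCNN.apply(rezeroWeights)        # zero out the weights masked by SparseWeights
sparseCNN.apply(updateBoostStrength)  # apply the boost strength factor to the k-winners layers
print(test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss))

# Remaining epochs on the full training loader; rezeroing after each epoch keeps
# the initial weight sparsity constant throughout training.
for epoch in range(1, EPOCHS):
    train(model=sparseCNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)
    sparseCNN.apply(rezeroWeights)
    sparseCNN.apply(updateBoostStrength)
    print(test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss))

Because nn.Module.apply recurses into every submodule, the whole Sequential model can be handed to both helpers; they are written to act only on the module types they recognize and leave ordinary layers untouched.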
examples/sparse_linear.ipynb (+73 -31)

@@ -30,8 +30,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
-"metadata": {},
+"execution_count": null,
+"metadata": {
+"scrolled": true
+},
 "outputs": [],
 "source": [
 "# Uncomment the following line to install nupic.torch\n",
@@ -188,7 +190,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from nupic.torch.modules import KWinners, SparseWeights, Flatten\n",
+"from nupic.torch.modules import (\n",
+"    KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n",
 "\n",
 "sparseNN = nn.Sequential(\n",
 "    Flatten(),\n",
@@ -244,53 +247,90 @@
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
+"outputs": [],
+"source": [
+"sgd = optim.SGD(sparseNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
+"train(model=sparseNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"After each epoch we rezero the weights to keep the initial sparsity constant during training. We also apply the boost strength factor after each epoch"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"metadata": {},
+"outputs": [],
+"source": [
+"%%capture\n",
+"sparseNN.apply(rezeroWeights)\n",
+"sparseNN.apply(updateBoostStrength)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Test and print results"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {},
 "outputs": [
 {
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"{'accuracy': 0.9506, 'loss': 0.16315889587402344, 'total_correct': 9506}\n"
-]
+"data": {
+"text/plain": [
+"{'accuracy': 0.9486, 'loss': 0.16556934661865233, 'total_correct': 9486}"
+]
+},
+"execution_count": 10,
+"metadata": {},
+"output_type": "execute_result"
 }
 ],
 "source": [
-"sgd = optim.SGD(sparseNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
-"train(model=sparseNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n",
-"results = test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)\n",
-"print(results)"
+"test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"At this point the duty cycles should be stable and we can train on larger batch sizes"
+"At this point the duty cycles should be stable and we can train the rest of the epochs on larger batch sizes"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"{'accuracy': 0.9627, 'loss': 0.12201591415405273, 'total_correct': 9627}\n",
-"{'accuracy': 0.9634, 'loss': 0.12190820617675781, 'total_correct': 9634}\n",
-"{'accuracy': 0.9623, 'loss': 0.12345575714111329, 'total_correct': 9623}\n",
-"{'accuracy': 0.9639, 'loss': 0.1185587173461914, 'total_correct': 9639}\n",
-"{'accuracy': 0.9611, 'loss': 0.11994301071166992, 'total_correct': 9611}\n",
-"{'accuracy': 0.9633, 'loss': 0.11600606689453125, 'total_correct': 9633}\n",
-"{'accuracy': 0.9634, 'loss': 0.11699238586425781, 'total_correct': 9634}\n",
-"{'accuracy': 0.9639, 'loss': 0.11530724716186523, 'total_correct': 9639}\n",
-"{'accuracy': 0.9633, 'loss': 0.11797227020263672, 'total_correct': 9633}\n"
+"{'accuracy': 0.9621, 'loss': 0.12309523391723633, 'total_correct': 9621}\n",
+"{'accuracy': 0.9625, 'loss': 0.12451380462646484, 'total_correct': 9625}\n",
+"{'accuracy': 0.9621, 'loss': 0.12468773880004883, 'total_correct': 9621}\n",
+"{'accuracy': 0.9638, 'loss': 0.11706881561279296, 'total_correct': 9638}\n",
+"{'accuracy': 0.9623, 'loss': 0.120688623046875, 'total_correct': 9623}\n",
+"{'accuracy': 0.9645, 'loss': 0.11490174255371094, 'total_correct': 9645}\n",
+"{'accuracy': 0.9648, 'loss': 0.1163398452758789, 'total_correct': 9648}\n",
+"{'accuracy': 0.9643, 'loss': 0.1144802864074707, 'total_correct': 9643}\n",
+"{'accuracy': 0.9657, 'loss': 0.11591439743041992, 'total_correct': 9657}\n"
 ]
 }
 ],
 "source": [
 "for epoch in range(1, EPOCHS):\n",
 "    train(model=sparseNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n",
+"    sparseNN.apply(updateBoostStrength)\n",
+"    sparseNN.apply(rezeroWeights)\n",
 "    results = test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)\n",
 "    print(results)"
 ]
@@ -305,18 +345,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
-"metadata": {},
+"execution_count": 12,
+"metadata": {
+"scrolled": true
+},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"0.05 : {'accuracy': 0.9541, 'loss': 0.14802040100097658, 'total_correct': 9541}\n",
-"0.1 : {'accuracy': 0.939, 'loss': 0.1916733642578125, 'total_correct': 9390}\n",
-"0.15 : {'accuracy': 0.9094, 'loss': 0.2852293640136719, 'total_correct': 9094}\n",
-"0.2 : {'accuracy': 0.8639, 'loss': 0.4125948455810547, 'total_correct': 8639}\n",
-"0.25 : {'accuracy': 0.8043, 'loss': 0.5801907653808593, 'total_correct': 8043}\n"
+"0.05 : {'accuracy': 0.9563, 'loss': 0.14633502197265624, 'total_correct': 9563}\n",
+"0.1 : {'accuracy': 0.9406, 'loss': 0.19336406707763673, 'total_correct': 9406}\n",
+"0.15 : {'accuracy': 0.9114, 'loss': 0.28374275512695313, 'total_correct': 9114}\n",
+"0.2 : {'accuracy': 0.8656, 'loss': 0.4086779327392578, 'total_correct': 8656}\n",
+"0.25 : {'accuracy': 0.8082, 'loss': 0.5770498291015626, 'total_correct': 8082}\n"
 ]
 }
 ],
@@ -348,7 +390,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.7"
+"version": "3.7.3"
 }
 },
 "nbformat": 4,

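The linear notebook gets the same treatment, with one small difference: inside the epoch loop it applies updateBoostStrength before rezeroWeights, while the CNN notebook uses the opposite order. Since each helper mutates a different kind of layer (boost state in the k-winners modules, the weight mask in the SparseWeights wrappers), the order at that point in the loop should not matter. A small sanity check of what rezeroWeights preserves, assuming the notebook's sparseNN is in scope; zero_fraction is a hypothetical helper, not part of nupic.torch:

import torch.nn as nn
from nupic.torch.modules import rezeroWeights

def zero_fraction(model):
    # Fraction of exactly-zero weights across all Linear layers in the model.
    zeros, total = 0, 0
    for m in model.modules():
        if isinstance(m, nn.Linear):
            zeros += (m.weight == 0).sum().item()
            total += m.weight.numel()
    return zeros / max(total, 1)

print("before rezero:", zero_fraction(sparseNN))  # may have drifted after optimizer steps
sparseNN.apply(rezeroWeights)                     # restore the construction-time zero mask
print("after rezero: ", zero_fraction(sparseNN))  # back to (at least) the configured sparsity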
0 commit comments