diff --git a/examples/sparse_cnn.ipynb b/examples/sparse_cnn.ipynb index 408ffa4..c1f68c5 100644 --- a/examples/sparse_cnn.ipynb +++ b/examples/sparse_cnn.ipynb @@ -30,12 +30,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uncomment the following line to install nupic.torch\n", - "#!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" + "!pip install -e git+https://github.com/numenta/nupic.torch.git#egg=nupic.torch" ] }, { @@ -196,7 +196,8 @@ "metadata": {}, "outputs": [], "source": [ - "from nupic.torch.modules import KWinners2d, KWinners, SparseWeights, Flatten\n", + "from nupic.torch.modules import (\n", + " KWinners2d, KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n", "\n", "sparseCNN = nn.Sequential(\n", " # Sparse CNN layer\n", @@ -255,20 +256,55 @@ "cell_type": "code", "execution_count": 8, "metadata": {}, + "outputs": [], + "source": [ + "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", + "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After each epoch we rezero the weights to keep the initial sparsity constant during training. We also apply the boost strength factor after each epoch" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "sparseCNN.apply(rezeroWeights)\n", + "sparseCNN.apply(updateBoostStrength)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test and print results" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'accuracy': 0.978, 'loss': 0.06757185096740723, 'total_correct': 9780}\n" - ] + "data": { + "text/plain": [ + "{'accuracy': 0.9782, 'loss': 0.06787856979370117, 'total_correct': 9782}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "sgd = optim.SGD(sparseCNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", - "train(model=sparseCNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n", - "results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n", - "print(results)" + "test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)" ] }, { @@ -280,28 +316,30 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'accuracy': 0.9856, 'loss': 0.04167602500915527, 'total_correct': 9856}\n", - "{'accuracy': 0.9868, 'loss': 0.040336785697937014, 'total_correct': 9868}\n", - "{'accuracy': 0.9872, 'loss': 0.03909029731750488, 'total_correct': 9872}\n", - "{'accuracy': 0.9874, 'loss': 0.037309212684631346, 'total_correct': 9874}\n", - "{'accuracy': 0.9876, 'loss': 0.037822017288208006, 'total_correct': 9876}\n", - "{'accuracy': 0.9877, 'loss': 0.03768303909301758, 'total_correct': 9877}\n", - "{'accuracy': 0.988, 'loss': 0.03783873291015625, 'total_correct': 9880}\n", - "{'accuracy': 0.9881, 'loss': 0.038252718925476076, 'total_correct': 9881}\n", - "{'accuracy': 0.9873, 'loss': 0.03846522216796875, 'total_correct': 9873}\n" + "{'accuracy': 0.9862, 'loss': 0.0412615779876709, 'total_correct': 9862}\n", + "{'accuracy': 0.9868, 'loss': 0.04029187545776367, 'total_correct': 9868}\n", + 
"{'accuracy': 0.9867, 'loss': 0.03934368209838867, 'total_correct': 9867}\n", + "{'accuracy': 0.9876, 'loss': 0.03759277114868164, 'total_correct': 9876}\n", + "{'accuracy': 0.9877, 'loss': 0.03777754402160644, 'total_correct': 9877}\n", + "{'accuracy': 0.9872, 'loss': 0.03784116630554199, 'total_correct': 9872}\n", + "{'accuracy': 0.9872, 'loss': 0.03829168930053711, 'total_correct': 9872}\n", + "{'accuracy': 0.9876, 'loss': 0.03837165260314941, 'total_correct': 9876}\n", + "{'accuracy': 0.9871, 'loss': 0.03919747161865234, 'total_correct': 9871}\n" ] } ], "source": [ "for epoch in range(1, EPOCHS):\n", " train(model=sparseCNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n", + " sparseCNN.apply(rezeroWeights)\n", + " sparseCNN.apply(updateBoostStrength)\n", " results = test(model=sparseCNN, loader=test_loader, criterion=F.nll_loss)\n", " print(results)" ] @@ -316,18 +354,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.05 : {'accuracy': 0.9842, 'loss': 0.04920544128417969, 'total_correct': 9842}\n", - "0.1 : {'accuracy': 0.9782, 'loss': 0.067373148727417, 'total_correct': 9782}\n", - "0.15 : {'accuracy': 0.9697, 'loss': 0.0976463264465332, 'total_correct': 9697}\n", - "0.2 : {'accuracy': 0.9524, 'loss': 0.15406905670166016, 'total_correct': 9524}\n", - "0.25 : {'accuracy': 0.9238, 'loss': 0.23606297912597657, 'total_correct': 9238}\n" + "0.05 : {'accuracy': 0.9841, 'loss': 0.05029125137329102, 'total_correct': 9841}\n", + "0.1 : {'accuracy': 0.9777, 'loss': 0.06853677139282227, 'total_correct': 9777}\n", + "0.15 : {'accuracy': 0.9697, 'loss': 0.09935387115478515, 'total_correct': 9697}\n", + "0.2 : {'accuracy': 0.9509, 'loss': 0.1598887924194336, 'total_correct': 9509}\n", + "0.25 : {'accuracy': 0.9225, 'loss': 0.2412475631713867, 'total_correct': 9225}\n" ] } ], @@ -359,7 +397,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.7.3" } }, "nbformat": 4, diff --git a/examples/sparse_linear.ipynb b/examples/sparse_linear.ipynb index d14a4bf..0a70d03 100644 --- a/examples/sparse_linear.ipynb +++ b/examples/sparse_linear.ipynb @@ -30,8 +30,10 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, + "execution_count": null, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "# Uncomment the following line to install nupic.torch\n", @@ -188,7 +190,8 @@ "metadata": {}, "outputs": [], "source": [ - "from nupic.torch.modules import KWinners, SparseWeights, Flatten\n", + "from nupic.torch.modules import (\n", + " KWinners, SparseWeights, Flatten, rezeroWeights, updateBoostStrength)\n", "\n", "sparseNN = nn.Sequential(\n", " Flatten(),\n", @@ -244,53 +247,90 @@ "cell_type": "code", "execution_count": 8, "metadata": {}, + "outputs": [], + "source": [ + "sgd = optim.SGD(sparseNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", + "train(model=sparseNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After each epoch we rezero the weights to keep the initial sparsity constant during training. 
We also apply the boost strength factor after each epoch" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "sparseNN.apply(rezeroWeights)\n", + "sparseNN.apply(updateBoostStrength)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test and print results" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'accuracy': 0.9506, 'loss': 0.16315889587402344, 'total_correct': 9506}\n" - ] + "data": { + "text/plain": [ + "{'accuracy': 0.9486, 'loss': 0.16556934661865233, 'total_correct': 9486}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "sgd = optim.SGD(sparseNN.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", - "train(model=sparseNN, loader=first_loader, optimizer=sgd, criterion=F.nll_loss)\n", - "results = test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)\n", - "print(results)" + "test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "At this point the duty cycles should be stable and we can train on larger batch sizes" + "At this point the duty cycles should be stable and we can train the rest of the epochs on larger batch sizes" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'accuracy': 0.9627, 'loss': 0.12201591415405273, 'total_correct': 9627}\n", - "{'accuracy': 0.9634, 'loss': 0.12190820617675781, 'total_correct': 9634}\n", - "{'accuracy': 0.9623, 'loss': 0.12345575714111329, 'total_correct': 9623}\n", - "{'accuracy': 0.9639, 'loss': 0.1185587173461914, 'total_correct': 9639}\n", - "{'accuracy': 0.9611, 'loss': 0.11994301071166992, 'total_correct': 9611}\n", - "{'accuracy': 0.9633, 'loss': 0.11600606689453125, 'total_correct': 9633}\n", - "{'accuracy': 0.9634, 'loss': 0.11699238586425781, 'total_correct': 9634}\n", - "{'accuracy': 0.9639, 'loss': 0.11530724716186523, 'total_correct': 9639}\n", - "{'accuracy': 0.9633, 'loss': 0.11797227020263672, 'total_correct': 9633}\n" + "{'accuracy': 0.9621, 'loss': 0.12309523391723633, 'total_correct': 9621}\n", + "{'accuracy': 0.9625, 'loss': 0.12451380462646484, 'total_correct': 9625}\n", + "{'accuracy': 0.9621, 'loss': 0.12468773880004883, 'total_correct': 9621}\n", + "{'accuracy': 0.9638, 'loss': 0.11706881561279296, 'total_correct': 9638}\n", + "{'accuracy': 0.9623, 'loss': 0.120688623046875, 'total_correct': 9623}\n", + "{'accuracy': 0.9645, 'loss': 0.11490174255371094, 'total_correct': 9645}\n", + "{'accuracy': 0.9648, 'loss': 0.1163398452758789, 'total_correct': 9648}\n", + "{'accuracy': 0.9643, 'loss': 0.1144802864074707, 'total_correct': 9643}\n", + "{'accuracy': 0.9657, 'loss': 0.11591439743041992, 'total_correct': 9657}\n" ] } ], "source": [ "for epoch in range(1, EPOCHS):\n", " train(model=sparseNN, loader=train_loader, optimizer=sgd, criterion=F.nll_loss)\n", + " sparseNN.apply(updateBoostStrength)\n", + " sparseNN.apply(rezeroWeights)\n", " results = test(model=sparseNN, loader=test_loader, criterion=F.nll_loss)\n", " print(results)" ] @@ -305,18 +345,20 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, + "execution_count": 12, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.05 : 
{'accuracy': 0.9541, 'loss': 0.14802040100097658, 'total_correct': 9541}\n", - "0.1 : {'accuracy': 0.939, 'loss': 0.1916733642578125, 'total_correct': 9390}\n", - "0.15 : {'accuracy': 0.9094, 'loss': 0.2852293640136719, 'total_correct': 9094}\n", - "0.2 : {'accuracy': 0.8639, 'loss': 0.4125948455810547, 'total_correct': 8639}\n", - "0.25 : {'accuracy': 0.8043, 'loss': 0.5801907653808593, 'total_correct': 8043}\n" + "0.05 : {'accuracy': 0.9563, 'loss': 0.14633502197265624, 'total_correct': 9563}\n", + "0.1 : {'accuracy': 0.9406, 'loss': 0.19336406707763673, 'total_correct': 9406}\n", + "0.15 : {'accuracy': 0.9114, 'loss': 0.28374275512695313, 'total_correct': 9114}\n", + "0.2 : {'accuracy': 0.8656, 'loss': 0.4086779327392578, 'total_correct': 8656}\n", + "0.25 : {'accuracy': 0.8082, 'loss': 0.5770498291015626, 'total_correct': 8082}\n" ] } ], @@ -348,7 +390,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.7.3" } }, "nbformat": 4,
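
Both notebooks now share the same per-epoch maintenance step: after every training epoch the model is passed through rezeroWeights, which keeps the weight sparsity fixed at initialization, and updateBoostStrength, which applies the boost strength factor to the k-winners layers. The sketch below is only an illustration of that pattern as it appears in this patch, not code taken from the notebooks: the helper names (train_one_epoch, evaluate, fit_sparse_model) are hypothetical stand-ins for the notebooks' own train() and test() helpers, and the model, data loaders, and hyperparameters are assumed to be built as in the earlier notebook cells.

# Hypothetical standalone sketch of the per-epoch pattern this patch adds to both
# notebooks.  `model` is assumed to be an nn.Sequential built from nupic.torch
# sparse modules (SparseWeights, KWinners/KWinners2d) that ends in log-softmax,
# and the loaders are assumed to be standard torch DataLoaders over MNIST.
import torch
import torch.nn.functional as F
import torch.optim as optim

from nupic.torch.modules import rezeroWeights, updateBoostStrength


def train_one_epoch(model, loader, optimizer):
    # Stand-in for the notebooks' train() helper: one pass over the loader.
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()


def evaluate(model, loader):
    # Stand-in for the notebooks' test() helper: returns accuracy on the loader.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total


def fit_sparse_model(model, first_loader, train_loader, test_loader,
                     epochs, lr, momentum):
    sgd = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # First epoch uses the small-batch loader so the k-winners duty cycles
    # stabilize, then the two maintenance steps the patch inserts after
    # every epoch are applied.
    train_one_epoch(model, first_loader, sgd)
    model.apply(rezeroWeights)        # keep the initial weight sparsity constant
    model.apply(updateBoostStrength)  # apply the boost strength factor

    # Remaining epochs train on the larger-batch loader with the same
    # per-epoch maintenance, reporting test accuracy as the notebooks do.
    for _ in range(1, epochs):
        train_one_epoch(model, train_loader, sgd)
        model.apply(rezeroWeights)
        model.apply(updateBoostStrength)
        print(evaluate(model, test_loader))

Note that the two notebooks apply the two steps in opposite orders (sparse_cnn.ipynb rezeroes before updating boost strength, sparse_linear.ipynb after); the sketch follows the CNN ordering.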