Skip to content

Commit 8638ba3

Browse files
authored
Merge pull request #7 from lscheinkman/add_test
Add duty_cycle_metrics_test
2 parents a870800 + dd2b182 commit 8638ba3

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

tests/duty_cycle_metrics_test.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# ----------------------------------------------------------------------
2+
# Numenta Platform for Intelligent Computing (NuPIC)
3+
# Copyright (C) 2019, Numenta, Inc. Unless you have an agreement
4+
# with Numenta, Inc., for a separate license for this software code, the
5+
# following terms and conditions apply:
6+
#
7+
# This program is free software: you can redistribute it and/or modify
8+
# it under the terms of the GNU Affero Public License version 3 as
9+
# published by the Free Software Foundation.
10+
#
11+
# This program is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14+
# See the GNU Affero Public License for more details.
15+
#
16+
# You should have received a copy of the GNU Affero Public License
17+
# along with this program. If not, see http://www.gnu.org/licenses.
18+
#
19+
# http://numenta.org/licenses/
20+
# ----------------------------------------------------------------------
21+
22+
import unittest
23+
24+
import torch
25+
26+
from nupic.torch.duty_cycle_metrics import \
27+
binaryEntropy, maxEntropy
28+
29+
30+
class DutyCycleMetricsTest(unittest.TestCase):
31+
"""
32+
Simplistic tests of duty cycle entropy metrics
33+
"""
34+
35+
def testBinaryEntropy(self):
36+
37+
p = torch.tensor([0.1, 0.02, 0.99, 0.5, 0.75, 0.8, 0.3, 0.4, 0.0, 1.0])
38+
entropy, entropySum = binaryEntropy(p)
39+
self.assertAlmostEqual(entropySum.item(), 5.076676985, places=4)
40+
self.assertAlmostEqual(entropySum.item(), entropy.sum(), places=4)
41+
self.assertAlmostEqual(entropy[0].item(), 0.468995594, places=4)
42+
self.assertAlmostEqual(entropy[1].item(), 0.141440543, places=4)
43+
self.assertAlmostEqual(entropy[2].item(), 0.080793136, places=4)
44+
self.assertEqual(entropy[8].item(), 0.0)
45+
self.assertEqual(entropy[9].item(), 0.0)
46+
47+
p = torch.tensor([0.25, 0.25, 0.25, 0.25])
48+
entropy, entropySum = binaryEntropy(p)
49+
self.assertAlmostEqual(entropySum, 3.245112498, places=4)
50+
self.assertAlmostEqual(entropySum, entropy.sum(), places=4)
51+
52+
p = torch.tensor([0.5, 0.5, 0.5, 0.5])
53+
entropy, entropySum = binaryEntropy(p)
54+
self.assertAlmostEqual(entropySum, 4.0, places=4)
55+
self.assertAlmostEqual(entropySum, entropy.sum(), places=4)
56+
self.assertAlmostEqual(entropy[0], 1.0, places=4)
57+
self.assertAlmostEqual(entropy[1], 1.0, places=4)
58+
self.assertAlmostEqual(entropy[2], 1.0, places=4)
59+
self.assertAlmostEqual(entropy[3], 1.0, places=4)
60+
61+
62+
def testMaxEntropy(self):
63+
64+
entropy = maxEntropy(1,1)
65+
self.assertAlmostEqual(entropy, 0.0, places=4)
66+
67+
entropy = maxEntropy(1,0)
68+
self.assertAlmostEqual(entropy, 0.0, places=4)
69+
70+
entropy = maxEntropy(4,1)
71+
self.assertAlmostEqual(entropy, 3.245112498, places=4)
72+
73+
entropy = maxEntropy(4,2)
74+
self.assertAlmostEqual(entropy, 4.0, places=4)
75+
76+
entropy = maxEntropy(100,1)
77+
self.assertAlmostEqual(entropy, 8.07931359, places=4)
78+
79+
entropy = maxEntropy(100,10)
80+
self.assertAlmostEqual(entropy, 46.89955936, places=4)
81+
82+
entropy = maxEntropy(2048, 40)
83+
self.assertAlmostEqual(entropy, 284.2634199, places=4)
84+
85+
86+
if __name__ == "__main__":
87+
unittest.main()

tests/k_winners_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def testOne(self):
104104
self.assertEqual(
105105
(grad_x[indices] == self.gradient.reshape(-1)[indices]).sum(), 4)
106106
self.assertAlmostEqual(
107-
grad_x.sum(), self.gradient.reshape(-1)[indices].sum(), places=4)
107+
grad_x.sum().item(), self.gradient.reshape(-1)[indices].sum().item(), places=4)
108108
self.assertEqual(len(grad_x.nonzero()), 4)
109109

110110

0 commit comments

Comments
 (0)