Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
olegranmo committed Aug 17, 2024
1 parent 5319b4e commit 14709bc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
13 changes: 8 additions & 5 deletions examples/classification/SequenceInterpretabilityDemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def main(args):
position_2 = position_1+1
position_3 = position_1+2

#position_2 = np.random.randint(position_1+1, args.sequence_length-1)
#position_3 = np.random.randint(position_2+1, args.sequence_length)
# position_2 = np.random.randint(position_1+1, args.sequence_length-1)
# position_3 = np.random.randint(position_2+1, args.sequence_length)

if Y_train[i] == 0:
X_train[i,0,position_1,0] = 1
Expand Down Expand Up @@ -72,8 +72,8 @@ def main(args):
position_2 = position_1+1
position_3 = position_1+2

#position_2 = np.random.randint(position_1+1, args.sequence_length-1)
#position_3 = np.random.randint(position_2+1, args.sequence_length)
# position_2 = np.random.randint(position_1+1, args.sequence_length-1)
# position_3 = np.random.randint(position_2+1, args.sequence_length)

if Y_test[i] == 0:
X_test[i,0,position_1,0] = 1
Expand Down Expand Up @@ -103,7 +103,7 @@ def main(args):
X_test[i,0,position_3,0] = 1
X_test[i,0,position_3,1] = 0

tm = TMClassifier(args.number_of_clauses, args.T, args.s, number_of_state_bits_ta=10, patch_dim=(1, 1), weighted_clauses=True, platform=args.platform, boost_true_positive_feedback=True, spatio_temporal=True, incremental=False, max_included_literals=32)
tm = TMClassifier(args.number_of_clauses, args.T, args.s, number_of_state_bits_ta=args.number_of_state_bits_ta, patch_dim=(1, 1), weighted_clauses=True, platform=args.platform, boost_true_positive_feedback=True, spatio_temporal=True, incremental=False, max_included_literals=args.max_included_literals, depth=args.depth)

for i in range(args.epochs):
tm.fit(X_train, Y_train)
Expand Down Expand Up @@ -218,6 +218,9 @@ def default_args(**kwargs):
parser.add_argument("--sequence-length", default=6, type=int)
parser.add_argument("--noise", default=0.01, type=float, help="Noisy XOR")
parser.add_argument("--examples", default=40000, type=int, help="Noisy XOR")
parser.add_argument("--depth", default=2, type=int)
parser.add_argument("--number-of-state-bits-ta", default=10, type=int)
parser.add_argument("--max-included-literals", default=32, type=int)

args = parser.parse_args()
for key, value in kwargs.items():
Expand Down
6 changes: 3 additions & 3 deletions tmu/lib/src/ClauseBank.c
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ void cb_calculate_spatio_temporal_features(
);

// Just after
if (patch > 0) {
if (patch > 0 && !(d % 2)) {
if (clause_output) {
chunk_nr = (number_of_clauses*4*d + j) / 32;
chunk_pos = (number_of_clauses*4*d + j) % 32;
Expand All @@ -1542,7 +1542,7 @@ void cb_calculate_spatio_temporal_features(
}

// Just before
if (patch < number_of_patches - 1) {
if (patch < number_of_patches - 1 && !(d % 2)) {
if (clause_output) {
chunk_nr = (number_of_clauses*4*d + j + number_of_clauses) / 32;
chunk_pos = (number_of_clauses*4*d + j + number_of_clauses) % 32;
Expand All @@ -1554,7 +1554,7 @@ void cb_calculate_spatio_temporal_features(
}
}

if (clause_output) {
if (clause_output && (d % 2)) {
// After

for (int patch_before = 0; patch_before < patch; ++patch_before) {
Expand Down

0 comments on commit 14709bc

Please sign in to comment.