From 14709bcc5e12160419acb6f4cb6f0527d8e0d8c6 Mon Sep 17 00:00:00 2001 From: Ole-Christoffer Granmo Date: Sat, 17 Aug 2024 16:53:10 +0200 Subject: [PATCH] Update --- .../classification/SequenceInterpretabilityDemo.py | 13 ++++++++----- tmu/lib/src/ClauseBank.c | 6 +++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/classification/SequenceInterpretabilityDemo.py b/examples/classification/SequenceInterpretabilityDemo.py index 45023e84..bbfb139f 100644 --- a/examples/classification/SequenceInterpretabilityDemo.py +++ b/examples/classification/SequenceInterpretabilityDemo.py @@ -31,8 +31,8 @@ def main(args): position_2 = position_1+1 position_3 = position_1+2 - #position_2 = np.random.randint(position_1+1, args.sequence_length-1) - #position_3 = np.random.randint(position_2+1, args.sequence_length) +# position_2 = np.random.randint(position_1+1, args.sequence_length-1) +# position_3 = np.random.randint(position_2+1, args.sequence_length) if Y_train[i] == 0: X_train[i,0,position_1,0] = 1 @@ -72,8 +72,8 @@ def main(args): position_2 = position_1+1 position_3 = position_1+2 - #position_2 = np.random.randint(position_1+1, args.sequence_length-1) - #position_3 = np.random.randint(position_2+1, args.sequence_length) +# position_2 = np.random.randint(position_1+1, args.sequence_length-1) +# position_3 = np.random.randint(position_2+1, args.sequence_length) if Y_test[i] == 0: X_test[i,0,position_1,0] = 1 @@ -103,7 +103,7 @@ def main(args): X_test[i,0,position_3,0] = 1 X_test[i,0,position_3,1] = 0 - tm = TMClassifier(args.number_of_clauses, args.T, args.s, number_of_state_bits_ta=10, patch_dim=(1, 1), weighted_clauses=True, platform=args.platform, boost_true_positive_feedback=True, spatio_temporal=True, incremental=False, max_included_literals=32) + tm = TMClassifier(args.number_of_clauses, args.T, args.s, number_of_state_bits_ta=args.number_of_state_bits_ta, patch_dim=(1, 1), weighted_clauses=True, platform=args.platform, boost_true_positive_feedback=True, spatio_temporal=True, incremental=False, max_included_literals=args.max_included_literals, depth=args.depth) for i in range(args.epochs): tm.fit(X_train, Y_train) @@ -218,6 +218,9 @@ def default_args(**kwargs): parser.add_argument("--sequence-length", default=6, type=int) parser.add_argument("--noise", default=0.01, type=float, help="Noisy XOR") parser.add_argument("--examples", default=40000, type=int, help="Noisy XOR") + parser.add_argument("--depth", default=2, type=int) + parser.add_argument("--number-of-state-bits-ta", default=10, type=int) + parser.add_argument("--max-included-literals", default=32, type=int) args = parser.parse_args() for key, value in kwargs.items(): diff --git a/tmu/lib/src/ClauseBank.c b/tmu/lib/src/ClauseBank.c index b1b81114..86ddd2c5 100644 --- a/tmu/lib/src/ClauseBank.c +++ b/tmu/lib/src/ClauseBank.c @@ -1529,7 +1529,7 @@ void cb_calculate_spatio_temporal_features( ); // Just after - if (patch > 0) { + if (patch > 0 && !(d % 2)) { if (clause_output) { chunk_nr = (number_of_clauses*4*d + j) / 32; chunk_pos = (number_of_clauses*4*d + j) % 32; @@ -1542,7 +1542,7 @@ void cb_calculate_spatio_temporal_features( } // Just before - if (patch < number_of_patches - 1) { + if (patch < number_of_patches - 1 && !(d % 2)) { if (clause_output) { chunk_nr = (number_of_clauses*4*d + j + number_of_clauses) / 32; chunk_pos = (number_of_clauses*4*d + j + number_of_clauses) % 32; @@ -1554,7 +1554,7 @@ void cb_calculate_spatio_temporal_features( } } - if (clause_output) { + if (clause_output && (d % 2)) { // After for (int patch_before = 0; patch_before < patch; ++patch_before) {