Graph Tsetlin Machine
olegranmo committed Aug 25, 2024
1 parent 14709bc commit 4ea5e2e
Showing 9 changed files with 412 additions and 168 deletions.
28 changes: 16 additions & 12 deletions examples/classification/CIFAR2Demo3x3LiteralBudget.py
@@ -13,6 +13,7 @@

_LOGGER = logging.getLogger(__name__)

logging.basicConfig(level=logging.INFO)

def preprocess_cifar10_data(resolution, animals):
"""
@@ -35,9 +36,9 @@ def preprocess_cifar10_data(resolution, animals):
# Initialize empty arrays for quantized images
X_train = np.empty(
(X_train_org.shape[0], X_train_org.shape[1], X_train_org.shape[2], X_train_org.shape[3], resolution),
dtype=np.uint8)
dtype=np.uint32)
X_test = np.empty((X_test_org.shape[0], X_test_org.shape[1], X_test_org.shape[2], X_test_org.shape[3], resolution),
dtype=np.uint8)
dtype=np.uint32)

# Quantize pixel values
for z in range(resolution):
@@ -73,17 +74,18 @@ def run_ensemble(ensemble_params):
# Unpack parameters
args, X_train, Y_train, X_test, Y_test, ensemble = ensemble_params

T = int(args.clauses * 0.75)
tm = TMClassifier(
args.clauses,
T,
args.T,
args.s,
platform=args.platform,
patch_dim=(args.patch_size, args.patch_size),
number_of_state_bits_ta=args.number_of_state_bits_ta,
weighted_clauses=args.weighted_clauses,
literal_drop_p=args.literal_drop_p,
max_included_literals=args.max_included_literals
max_included_literals=args.max_included_literals,
spatio_temporal=True,
depth=args.depth
)

ensemble_results = metrics(args)
@@ -149,18 +151,20 @@ def main(args):

def default_args(**kwargs):
parser = argparse.ArgumentParser()
parser.add_argument("--max_included_literals", type=int, default=32)
parser.add_argument("--clauses", type=int, default=8000)
parser.add_argument("--max-included-literals", type=int, default=32)
parser.add_argument("--clauses", type=int, default=100)
parser.add_argument("--T", type=int, default=750)
parser.add_argument("--s", type=float, default=10.0)
parser.add_argument("--platform", type=str, default="GPU")
parser.add_argument("--patch_size", type=int, default=3)
parser.add_argument("--patch-size", type=int, default=3)
parser.add_argument("--resolution", type=int, default=8)
parser.add_argument("--number_of_state_bits_ta", type=int, default=8)
parser.add_argument("--literal_drop_p", type=float, default=0.0)
parser.add_argument("--number-of-state-bits-ta", type=int, default=8)
parser.add_argument("--literal-drop-p", type=float, default=0.0)
parser.add_argument("--depth", type=int, default=1)
parser.add_argument("--epochs", type=int, default=250)
parser.add_argument("--ensembles", type=int, default=5)
parser.add_argument("--ensembles", type=int, default=1)
parser.add_argument("--weighted-clauses", type=bool, default=True)
parser.add_argument("--use_multiprocessing", action='store_true', help="Use multiprocessing to run ensembles in parallel")
parser.add_argument("--use-multiprocessing", action='store_false', help="Use multiprocessing to run ensembles in parallel")
args = parser.parse_args()
for key, value in kwargs.items():
if key in args.__dict__:
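A side note on the renamed options in this file: argparse converts dashes in long option names to underscores when it builds the parsed namespace, so switching `--patch_size` to `--patch-size` (and likewise for the other flags) only changes the command-line spelling; the script still reads `args.patch_size`. A minimal, self-contained sketch, not code from this repository:

```python
import argparse

# Dashes in long option names become underscores in the parsed namespace,
# so the kebab-case renames above keep args.patch_size,
# args.max_included_literals, etc. working unchanged.
parser = argparse.ArgumentParser()
parser.add_argument("--patch-size", type=int, default=3)
parser.add_argument("--max-included-literals", type=int, default=32)

args = parser.parse_args(["--patch-size", "5"])
print(args.patch_size)             # 5
print(args.max_included_literals)  # 32
```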
@@ -661,7 +661,7 @@ def default_args(**kwargs):
parser.add_argument("--platform", default='CPU', type=str)
parser.add_argument("--T", default=100*2, type=int)
parser.add_argument("--s", default=1.0, type=float)
parser.add_argument("--sequence-length", default=6, type=int)
parser.add_argument("--sequence-length", default=10, type=int)
parser.add_argument("--noise", default=0.01, type=float)
parser.add_argument("--examples", default=40000, type=int)
parser.add_argument("--depth", default=2, type=int)
14 changes: 7 additions & 7 deletions examples/classification/SequenceInterpretabilityDemo.py
@@ -31,8 +31,8 @@ def main(args):
position_2 = position_1+1
position_3 = position_1+2

# position_2 = np.random.randint(position_1+1, args.sequence_length-1)
# position_3 = np.random.randint(position_2+1, args.sequence_length)
#position_2 = np.random.randint(position_1+1, args.sequence_length-1)
#position_3 = np.random.randint(position_2+1, args.sequence_length)

if Y_train[i] == 0:
X_train[i,0,position_1,0] = 1
@@ -72,8 +72,8 @@ def main(args):
position_2 = position_1+1
position_3 = position_1+2

# position_2 = np.random.randint(position_1+1, args.sequence_length-1)
# position_3 = np.random.randint(position_2+1, args.sequence_length)
#position_2 = np.random.randint(position_1+1, args.sequence_length-1)
#position_3 = np.random.randint(position_2+1, args.sequence_length)

if Y_test[i] == 0:
X_test[i,0,position_1,0] = 1
@@ -215,10 +215,10 @@ def default_args(**kwargs):
parser.add_argument("--platform", default='CPU', type=str)
parser.add_argument("--T", default=100*2, type=int)
parser.add_argument("--s", default=1.0, type=float)
parser.add_argument("--sequence-length", default=6, type=int)
parser.add_argument("--noise", default=0.01, type=float, help="Noisy XOR")
parser.add_argument("--sequence-length", default=10, type=int)
parser.add_argument("--noise", default=0.0, type=float, help="Noisy XOR")
parser.add_argument("--examples", default=40000, type=int, help="Noisy XOR")
parser.add_argument("--depth", default=2, type=int)
parser.add_argument("--depth", default=1, type=int)
parser.add_argument("--number-of-state-bits-ta", default=10, type=int)
parser.add_argument("--max-included-literals", default=32, type=int)

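For orientation, a small self-contained sketch of how one training example appears to be generated in this demo, based only on the lines visible above; the array shape and the label-1 branch are assumptions, not code from the repository:

```python
import numpy as np

sequence_length = 10  # matches the new --sequence-length default above

# Assumed layout: (examples, 1, sequence_length, symbols). Only the label-0
# branch is visible in the hunk, so the label-1 branch here is a guess.
X_train = np.zeros((1, 1, sequence_length, 2), dtype=np.uint32)
Y_train = np.random.randint(2, size=1)

position_1 = np.random.randint(sequence_length - 2)  # leave room for +1, +2
position_2 = position_1 + 1
position_3 = position_1 + 2

symbol = 0 if Y_train[0] == 0 else 1
X_train[0, 0, position_1, symbol] = 1
X_train[0, 0, position_2, symbol] = 1
X_train[0, 0, position_3, symbol] = 1
```

With the commented-out random positions re-enabled, the three active bits would no longer be consecutive, which appears to be the variant this consecutive-position setup is meant to contrast with.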
13 changes: 9 additions & 4 deletions tmu/clause_bank/base_clause_bank.py
@@ -40,6 +40,8 @@ def __init__(
self.type_ia_ii_feedback_ratio = type_ia_ii_feedback_ratio
self.spatio_temporal = spatio_temporal
self.depth = depth
self.hypervector_size = 256
self.hypervector_bits = 2

if len(X_shape) == 2:
self.dim = (X_shape[1], 1, 1)
@@ -53,17 +55,20 @@ def __init__(
if self.patch_dim is None:
self.patch_dim = (self.dim[0] * self.dim[1] * self.dim[2], 1)

self.number_of_features = int(
self.patch_dim[0] * self.patch_dim[1] * self.dim[2] + (self.dim[0] - self.patch_dim[0]) + (
self.dim[1] - self.patch_dim[1]))
self.number_of_input_features = int(self.patch_dim[0] * self.patch_dim[1] * self.dim[2])

self.number_of_features = int(self.patch_dim[0] * self.patch_dim[1] * self.dim[2])

self.number_of_patches = int((self.dim[0] - self.patch_dim[0] + 1) * (self.dim[1] - self.patch_dim[1] + 1))

if self.spatio_temporal:
self.number_of_features += self.number_of_clauses*4*self.depth;
self.number_of_features += self.depth*self.hypervector_size

self.number_of_input_literals = self.number_of_input_features * 2

self.number_of_literals = self.number_of_features * 2

self.number_of_input_ta_chunks = int((self.number_of_input_literals - 1) / 32 + 1)
self.number_of_ta_chunks = int((self.number_of_literals - 1) / 32 + 1)

self.max_included_literals = max_included_literals if max_included_literals else self.number_of_literals
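To make the new bookkeeping concrete, here is a worked sketch of the quantities above for one illustrative configuration; the concrete dimensions are assumptions chosen to resemble the CIFAR demo, not values read from the repository:

```python
# dim = (rows, cols, channels x thermometer levels), with 3x3 patches.
dim = (32, 32, 24)
patch_dim = (3, 3)
depth = 1
hypervector_size = 256
spatio_temporal = True

number_of_input_features = patch_dim[0] * patch_dim[1] * dim[2]   # 216
number_of_features = patch_dim[0] * patch_dim[1] * dim[2]         # 216

number_of_patches = (dim[0] - patch_dim[0] + 1) * (dim[1] - patch_dim[1] + 1)  # 900

if spatio_temporal:
    # Spatio-temporal features are now hypervector slots per depth level,
    # no longer 4 features per clause.
    number_of_features += depth * hypervector_size                 # 216 + 256 = 472

number_of_input_literals = number_of_input_features * 2            # 432
number_of_literals = number_of_features * 2                        # 944

number_of_input_ta_chunks = (number_of_input_literals - 1) // 32 + 1  # 14
number_of_ta_chunks = (number_of_literals - 1) // 32 + 1              # 30
```

Note that the old formula also counted the patch-position terms `(dim[0] - patch_dim[0]) + (dim[1] - patch_dim[1])`, which the new `number_of_features` no longer includes.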
81 changes: 61 additions & 20 deletions tmu/clause_bank/clause_bank.py
@@ -67,6 +67,8 @@ def __init__(
self.type_ia_feedback_counter = np.zeros(self.number_of_clauses, dtype=np.uint32, order="c")

if self.spatio_temporal:
self.xi_hypervector = np.empty(self.number_of_patches * self.number_of_ta_chunks, dtype=np.uint32, order="c")

self.clause_value_in_patch = np.empty(self.number_of_patches * self.number_of_clauses, dtype=np.uint32, order="c")
self.clause_value_in_patch_tmp = np.empty(self.number_of_patches * self.number_of_clauses, dtype=np.uint32, order="c")

@@ -79,6 +81,12 @@

self.attention = np.empty(self.number_of_ta_chunks, dtype=np.uint32, order="c")

self.hypervectors = np.empty((self.number_of_clauses, self.hypervector_bits), dtype=np.uint32, order="c")
indexes = np.arange(self.hypervector_size, dtype=np.uint32)
for i in range(self.number_of_clauses):
self.hypervectors[i,:] = np.random.choice(indexes, size=(self.hypervector_bits), replace=False)
self.hypervectors = self.hypervectors.reshape(self.number_of_clauses*self.hypervector_bits)
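The hypervectors initialized above are just `hypervector_bits` distinct random indices per clause into a `hypervector_size`-dimensional binary space. A short illustrative sketch (not part of the library) of expanding those index lists into dense binary codes, e.g. for inspection or debugging:

```python
import numpy as np

number_of_clauses = 4
hypervector_size = 256
hypervector_bits = 2

# Same sampling as above: each clause receives `hypervector_bits` distinct
# positions drawn from [0, hypervector_size).
indexes = np.arange(hypervector_size, dtype=np.uint32)
hypervectors = np.empty((number_of_clauses, hypervector_bits), dtype=np.uint32)
for i in range(number_of_clauses):
    hypervectors[i, :] = np.random.choice(indexes, size=hypervector_bits, replace=False)

# Expand the index pairs into dense sparse-binary codes.
dense = np.zeros((number_of_clauses, hypervector_size), dtype=np.uint32)
dense[np.arange(number_of_clauses)[:, None], hypervectors] = 1
print(dense.sum(axis=1))  # every row has exactly `hypervector_bits` ones
```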

# Incremental Clause Evaluation
self.literal_clause_map = np.empty(
(int(self.number_of_literals * self.number_of_clauses)),
@@ -122,7 +130,8 @@ def _cffi_init(self):
self.ptr_output_one_patches = ffi.cast("unsigned int *", self.output_one_patches.ctypes.data)
self.ptr_literal_clause_count = ffi.cast("unsigned int *", self.literal_clause_count.ctypes.data)
self.tiafc_p = ffi.cast("unsigned int *", self.type_ia_feedback_counter.ctypes.data)

self.xih_p = ffi.cast("unsigned int *", self.xi_hypervector.ctypes.data)

if self.spatio_temporal:
self.cvip_p = ffi.cast("unsigned int *", self.clause_value_in_patch.ctypes.data)
self.cvipt_p = ffi.cast("unsigned int *", self.clause_value_in_patch_tmp.ctypes.data)
Expand All @@ -135,6 +144,7 @@ def _cffi_init(self):
self.ctvtl_p = ffi.cast("unsigned int *", self.clause_truth_value_transitions_length.ctypes.data)

self.a_p = ffi.cast("unsigned int *", self.attention.ctypes.data)
self.hv_p = ffi.cast("unsigned int *", self.hypervectors.ctypes.data)

# Clause Initialization
self.ptr_ta_state = ffi.cast("unsigned int *", self.clause_bank.ctypes.data)
@@ -177,13 +187,24 @@ def calculate_clause_outputs_predict(self, encoded_X, e):

if not self.incremental or self.spatio_temporal:
if self.spatio_temporal:
lib.cb_prepare_hypervector(
self.number_of_input_features,
self.number_of_patches,
self.hypervector_size,
self.depth,
xi_p,
self.xih_p
)

lib.cb_calculate_spatio_temporal_features(
self.ptr_ta_state,
self.number_of_clauses,
self.number_of_literals,
self.number_of_features,
self.number_of_state_bits_ta,
self.number_of_patches,
self.depth,
self.hypervector_size,
self.hypervector_bits,
self.cvip_p,
self.cvipt_p,
self.ctc_p,
@@ -192,7 +213,8 @@
self.ctvt_p,
self.ctvtl_p,
self.a_p,
xi_p
self.hv_p,
self.xih_p
)

lib.cb_calculate_clause_outputs_predict_spatio_temporal(
@@ -208,7 +230,7 @@
self.cfcb_p,
self.ctvt_p,
self.ctvtl_p,
xi_p
self.xih_p
)
else:
lib.cb_calculate_clause_outputs_predict(
@@ -218,7 +240,7 @@
self.number_of_state_bits_ta,
self.number_of_patches,
self.co_p,
xi_p
self.xih_p
)
return self.clause_output

@@ -258,13 +280,24 @@ def calculate_clause_outputs_update(self, literal_active, encoded_X, e):
la_p = ffi.cast("unsigned int *", literal_active.ctypes.data)

if self.spatio_temporal:
lib.cb_prepare_hypervector(
self.number_of_input_features,
self.number_of_patches,
self.hypervector_size,
self.depth,
xi_p,
self.xih_p
)

lib.cb_calculate_spatio_temporal_features(
self.ptr_ta_state,
self.number_of_clauses,
self.number_of_literals,
self.number_of_features,
self.number_of_state_bits_ta,
self.number_of_patches,
self.depth,
self.hypervector_size,
self.hypervector_bits,
self.cvip_p,
self.cvipt_p,
self.ctc_p,
@@ -273,7 +306,8 @@
self.ctvt_p,
self.ctvtl_p,
self.a_p,
xi_p
self.hv_p,
self.xih_p
)

lib.cb_calculate_clause_outputs_update_spatio_temporal(
@@ -290,7 +324,7 @@
self.cfcb_p,
self.ctvt_p,
self.ctvtl_p,
xi_p
self.xih_p
)
else:
lib.cb_calculate_clause_outputs_update(
@@ -310,13 +344,24 @@ def calculate_clause_outputs_patchwise(self, encoded_X, e):
xi_p = ffi.cast("unsigned int *", encoded_X[e, :].ctypes.data)

if self.spatio_temporal:
lib.cb_prepare_hypervector(
self.number_of_input_features,
self.number_of_patches,
self.hypervector_size,
self.depth,
xi_p,
self.xih_p
)

lib.cb_calculate_spatio_temporal_features(
self.ptr_ta_state,
self.number_of_clauses,
self.number_of_literals,
self.number_of_features,
self.number_of_state_bits_ta,
self.number_of_patches,
self.depth,
self.hypervector_size,
self.hypervector_bits,
self.cvip_p,
self.cvipt_p,
self.ctc_p,
@@ -325,7 +370,8 @@
self.ctvt_p,
self.ctvtl_p,
self.a_p,
xi_p
self.hv_p,
self.xih_p
)

lib.cb_calculate_clause_outputs_patchwise(
@@ -374,7 +420,7 @@ def type_i_feedback(
self.cfcb_p,
self.ctvt_p,
self.ctvtl_p,
ptr_xi
self.xih_p
)
else:
lib.cb_type_i_feedback(
@@ -426,7 +472,7 @@ def type_ii_feedback(
self.cfcb_p,
self.ctvt_p,
self.ctvtl_p,
ptr_xi
self.xih_p
)
else:
lib.cb_type_ii_feedback(
Expand Down Expand Up @@ -543,19 +589,14 @@ def prepare_X(
self,
X
):
if self.spatio_temporal:
spatio_temporal_features = self.number_of_clauses*4*self.depth
else:
spatio_temporal_features = 0

return tmu.tools.encode(
X,
X.shape[0],
self.number_of_patches,
self.number_of_ta_chunks,
self.number_of_input_ta_chunks,
self.dim,
self.patch_dim,
spatio_temporal_features
0
)
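Related to the `prepare_X` change above, which now passes `number_of_input_ta_chunks` and `0` spatio-temporal features to the encoder: the chunk counts come from packing each literal pair (a feature and its negation) into 32-bit words. A rough sketch of that packing idea, purely illustrative and not the actual `tmu.tools.encode` implementation:

```python
import numpy as np

# Pack a binary literal vector (features followed by their negations) into
# 32-bit chunks, mirroring the ceil(2 * features / 32) chunk arithmetic.
features = np.random.randint(0, 2, size=216).astype(np.uint32)
literals = np.concatenate([features, 1 - features]).astype(np.uint32)

number_of_ta_chunks = (literals.shape[0] - 1) // 32 + 1   # ceil(432 / 32) = 14

packed = np.zeros(number_of_ta_chunks, dtype=np.uint32)
for k, bit in enumerate(literals):
    if bit:
        packed[k // 32] |= np.uint32(1) << np.uint32(k % 32)
```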

def prepare_X_autoencoder(
@@ -564,7 +605,7 @@ def prepare_X_autoencoder(
X_csc,
active_output
):
X = np.ascontiguousarray(np.empty(int(self.number_of_ta_chunks), dtype=np.uint32))
X = np.ascontiguousarray(np.empty(int(self.number_of_input_ta_chunks), dtype=np.uint32))
return X_csr, X_csc, active_output, X

def produce_autoencoder_example(