From 3d61d42feeccd4234b776f5c057f0cac6e605b78 Mon Sep 17 00:00:00 2001 From: Ole-Christoffer Granmo Date: Tue, 8 Oct 2024 14:42:25 +0200 Subject: [PATCH] Update --- .../cuda/calculate_clause_value_in_patch.cu | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu diff --git a/tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu b/tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu new file mode 100644 index 00000000..28c49488 --- /dev/null +++ b/tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu @@ -0,0 +1,93 @@ +/*** +# Copyright (c) 2021 Ole-Christoffer Granmo + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This code implements the Convolutional Tsetlin Machine from paper arXiv:1905.09688 +# https://arxiv.org/abs/1905.09688 +***/ + +#include + +extern "C" +{ + __global__ void calculate_clause_value_in_patch( + unsigned int *ta_state, + int number_of_clauses, + int number_of_literals, + int number_of_state_bits, + unsigned int *global_ta_state, + int *global_clause_node_output, + unsigned int *literal_active, + unsigned int *global_X, + int e + ) + { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + unsigned int clause_node_output; + + unsigned int filter; + if (((number_of_literals) % 32) != 0) { + filter = (~(0xffffffff << ((number_of_literals) % 32))); + } else { + filter = 0xffffffff; + } + unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1; + + int number_of_node_chunks = (NUMBER_OF_PATCHES - 1)/32 + 1; + unsigned int node_filter; + if ((NUMBER_OF_PATCHES % 32) != 0) { + node_filter = (~(0xffffffff << (NUMBER_OF_PATCHES % 32))); + } else { + node_filter = 0xffffffff; + } + + unsigned int *X = &global_X[e * number_of_ta_chunks * NUMBER_OF_PATCHES]; + + for (int clause_node_chunk = index; clause_node_chunk < (number_of_clauses)*(number_of_node_chunks); clause_node_chunk += stride) { + int clause = clause_node_chunk / number_of_node_chunks; + int node_chunk = clause_node_chunk % number_of_node_chunks; + + unsigned int *ta_state = &global_ta_state[clause*number_of_ta_chunks*number_of_state_bits]; + + clause_node_output = ~0; + for (int node_pos = 0; (node_pos < 32) && ((node_chunk * 32 + node_pos) < NUMBER_OF_PATCHES); ++node_pos) { + int node = node_chunk * 32 + node_pos; + + for (int la_chunk = 0; la_chunk < number_of_ta_chunks-1; ++la_chunk) { + if ((ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + la_chunk] | (!literal_active[la_chunk]))) != ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1]) { + clause_node_output &= ~(1 << node_pos); + } + } + + if ((ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + number_of_ta_chunks-1] | (!literal_active[la_chunk]))) & filter) != (ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & filter)) { + clause_node_output &= ~(1 << node_pos); + } + } + + if (node_chunk == number_of_node_chunks - 1) { + global_clause_node_output[clause*number_of_node_chunks + node_chunk] = clause_node_output & node_filter; + } else { + global_clause_node_output[clause*number_of_node_chunks + node_chunk] = clause_node_output; + } + } + } +} \ No newline at end of file