Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
olegranmo committed Oct 8, 2024
1 parent 3dd02d4 commit 8b5ca6c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
5 changes: 3 additions & 2 deletions tmu/clause_bank/clause_bank_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ def calculate_clause_outputs_predict(self, encoded_X, e):
current_clause_node_output = self.current_clause_node_output_test_gpu
next_clause_node_output = self.current_clause_node_output_test_gpu

self.ta_state_gpu = cuda.mem_alloc(self.clause_bank.nbytes)
cuda.memcpy_htod(self.ta_state_gpu, self.clause_bank)
clause_bank = self.clause_bank.reshape(-1)
self.ta_state_gpu = cuda.mem_alloc(clause_bank.nbytes)
cuda.memcpy_htod(self.ta_state_gpu, clause_bank)

self.attention_gpu = cuda.mem_alloc(self.attention.nbytes)
cuda.memcpy_htod(self.attention_gpu, self.attention)
Expand Down
29 changes: 14 additions & 15 deletions tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
extern "C"
{
__global__ void calculate_clause_value_in_patch(
unsigned int *ta_state,
int number_of_clauses,
int number_of_literals,
int number_of_state_bits,
Expand Down Expand Up @@ -62,24 +61,24 @@ extern "C"

for (int clause_node_chunk = index; clause_node_chunk < (number_of_clauses)*(number_of_node_chunks); clause_node_chunk += stride) {
int clause = clause_node_chunk / number_of_node_chunks;
int node_chunk = clause_node_chunk % number_of_node_chunks;
/int node_chunk = clause_node_chunk % number_of_node_chunks;

unsigned int *ta_state = &global_ta_state[clause*number_of_ta_chunks*number_of_state_bits];
// unsigned int *ta_state = &global_ta_state[clause*number_of_ta_chunks*number_of_state_bits];

clause_node_output = ~0;
for (int node_pos = 0; (node_pos < 32) && ((node_chunk * 32 + node_pos) < NUMBER_OF_PATCHES); ++node_pos) {
int node = node_chunk * 32 + node_pos;
// clause_node_output = ~0;
// for (int node_pos = 0; (node_pos < 32) && ((node_chunk * 32 + node_pos) < NUMBER_OF_PATCHES); ++node_pos) {
// int node = node_chunk * 32 + node_pos;

for (int la_chunk = 0; la_chunk < number_of_ta_chunks-1; ++la_chunk) {
if ((ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + la_chunk] | (!literal_active[la_chunk]))) != ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1]) {
clause_node_output &= ~(1 << node_pos);
}
}
// for (int la_chunk = 0; la_chunk < number_of_ta_chunks-1; ++la_chunk) {
// if ((ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + la_chunk] | (!literal_active[la_chunk]))) != ta_state[la_chunk*number_of_state_bits + number_of_state_bits - 1]) {
// clause_node_output &= ~(1 << node_pos);
// }
// }

if ((ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + number_of_ta_chunks-1] | (!literal_active[number_of_ta_chunks-1])) & filter) != (ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & filter)) {
clause_node_output &= ~(1 << node_pos);
}
}
// if ((ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & (X[node*number_of_ta_chunks + number_of_ta_chunks-1] | (!literal_active[number_of_ta_chunks-1])) & filter) != (ta_state[(number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits - 1] & filter)) {
// clause_node_output &= ~(1 << node_pos);
// }
// }

if (node_chunk == number_of_node_chunks - 1) {
global_clause_node_output[clause*number_of_node_chunks + node_chunk] = clause_node_output & node_filter;
Expand Down

0 comments on commit 8b5ca6c

Please sign in to comment.