Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
olegranmo committed Oct 8, 2024
1 parent 718ee26 commit a0177e5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
29 changes: 14 additions & 15 deletions tmu/clause_bank/clause_bank_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(

mod = load_cuda_kernel(parameters, "cuda/calculate_clause_value_in_patch.cu")
self.clause_value_in_patch_gpu = mod.get_function("calculate_clause_value_in_patch")
self.clause_value_in_patch_gpu.prepare("PiiiPPPPPi")
self.clause_value_in_patch_gpu.prepare("PiiiPPPPP")

self.clause_output = np.empty(self.number_of_clauses, dtype=np.uint32, order="c")
self.clause_and_target = np.zeros(self.number_of_clauses * self.number_of_ta_chunks, dtype=np.uint32, order="c")
Expand Down Expand Up @@ -263,20 +263,19 @@ def calculate_clause_outputs_predict(self, encoded_X, e):
self.encoded_X_gpu = cuda.mem_alloc(encoded_X[e, :].nbytes)
cuda.memcpy_htod(self.encoded_X_gpu, encoded_X[e, :])

# self.calculate_clause_value_in_patch.prepared_call(
# self.grid,
# self.block,
# self.number_of_clauses,
# self.number_of_features,
# self.number_of_state_bits_ta,
# self.ta_state_gpu,
# current_clause_node_output,
# next_clause_node_output,
# self.attention_gpu,
# encoded_X_gpu,
# e
# )
# cuda.Context.synchronize()
self.calculate_clause_value_in_patch.prepared_call(
self.grid,
self.block,
self.number_of_clauses,
self.number_of_features,
self.number_of_state_bits_ta,
self.ta_state_gpu,
current_clause_node_output,
next_clause_node_output,
self.attention_gpu,
encoded_X_gpu
)
cuda.Context.synchronize()

lib.cb_calculate_spatio_temporal_features(
self.ptr_ta_state,
Expand Down
5 changes: 1 addition & 4 deletions tmu/clause_bank/cuda/calculate_clause_value_in_patch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ extern "C"
int *global_clause_node_output,
int *global_clause_node_output_next,
unsigned int *literal_active,
unsigned int *global_X,
int e
unsigned int *X
)
{
int index = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -61,8 +60,6 @@ extern "C"
node_filter = 0xffffffff;
}

unsigned int *X = &global_X[e * number_of_ta_chunks * NUMBER_OF_PATCHES];

for (int clause_node_chunk = index; clause_node_chunk < (number_of_clauses)*(number_of_node_chunks); clause_node_chunk += stride) {
int clause = clause_node_chunk / number_of_node_chunks;
int node_chunk = clause_node_chunk % number_of_node_chunks;
Expand Down

0 comments on commit a0177e5

Please sign in to comment.