Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
gather/scatter optimization; bug fixed -- abort gather when embedding_table_idx < 0
Browse the repository at this point in the history
  • Loading branch information
linhu-nv committed Nov 15, 2023
1 parent d9e0286 commit e4768df
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
Original file line number Diff line number Diff line change
@@ -287,6 +287,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
if (!use_shm) { my_shared = output_ptr; }
int64_t embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
EmbeddingT* emb_ptr =
&embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];

@@ -429,6 +430,7 @@ __global__ void scatter_func_kernel(const InputT* input,
}
for (int e = 0; e < cur_idx_lines; e++) {
int64_t embedding_table_idx = indices[input_idx + e];
if (embedding_table_idx < 0) continue;
EmbeddingT* emb_ptr =
&embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];

Expand Down

0 comments on commit e4768df

Please sign in to comment.