diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index 95fa496df..62a7dcb5e 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -287,6 +287,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; if (!use_shm) { my_shared = output_ptr; } int64_t embedding_table_idx = indices[output_idx]; + if (embedding_table_idx < 0) continue; EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride]; @@ -429,6 +430,7 @@ __global__ void scatter_func_kernel(const InputT* input, } for (int e = 0; e < cur_idx_lines; e++) { int64_t embedding_table_idx = indices[input_idx + e]; + if (embedding_table_idx < 0) continue; EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_desc.storage_offset + embedding_table_idx * embedding_stride];