Skip to content

Commit

Permalink
Merge pull request #494 from sony/feature/20240306-top-k-zero-value
Browse files Browse the repository at this point in the history
Fix atomicInc being called too many times, which caused the counter to wrap around
  • Loading branch information
YukioOobuchi authored Apr 24, 2024
2 parents d9a65c6 + 95ce574 commit 8e5520b
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions include/nbla/cuda/utils/top_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ template <typename T, bool Largest> struct ValIdxBitonic {
compare(const ValIdxBitonic<T, Largest> &a,
const ValIdxBitonic<T, Largest> &b) {
TopKGreater<Largest> greater;
TopKLessEqual<Largest> less_equal;
if (a.body.v == b.body.v) {
return less_equal(a.body.i, b.body.i);
}
return greater(a.body.v, b.body.v);
}

Expand All @@ -221,44 +225,70 @@ template <typename T, bool Largest> struct ValIdxBitonic {
count is used as atomic counter and must initially be set to zero.
*/
// NOTE(review): this hunk is shown without +/- markers, so two versions of the
// kernel appear interleaved below. Presumably the lines using greater_equal
// and the `k` out-parameter belong to the removed init_val_idx_list, and the
// lines using `greater` belong to its replacement init_val_idx_list_greater.
// Confirm against the repository before editing.
//
// Both versions have each thread start at its global index and advance by
// blockDim.x * gridDim.x, read data[index] (through abs() when UseAbsVal is
// set), compare it against bucket->pivot, and append {value, index} to
// sort_data at the slot returned by atomicInc on bucket->count.
template <typename T, bool UseAbsVal, bool Largest>
__global__ void init_val_idx_list(const T *data, const int size,
Bucket<T> *bucket, ValIdx<T> *sort_data,
const unsigned int sort_data_size,
unsigned int K, unsigned int *k) {
TopKGreaterEqual<Largest> greater_equal;
__global__ void
init_val_idx_list_greater(const T *data, const int size, Bucket<T> *bucket,
ValIdx<T> *sort_data,
const unsigned int sort_data_size, unsigned int K) {
TopKGreater<Largest> greater;

const auto thread = blockIdx.x * blockDim.x + threadIdx.x;
const auto stride = blockDim.x * gridDim.x;

for (unsigned int index = thread; index < size; index += stride) {
T value = UseAbsVal ? abs(data[index]) : data[index];
if (greater_equal(value, bucket->pivot)) {
if (greater(value, bucket->pivot)) {
// atomicInc returns the previous counter value and resets the counter to 0
// once that value has reached sort_data_size.
// NOTE(review): the returned slot can therefore equal sort_data_size, and a
// wrapped counter makes later writes overwrite earlier slots. Confirm that
// the number of accepted elements always stays below sort_data_size here.
sort_data[atomicInc(&bucket->count, sort_data_size)] = {value, index};
}
}
*k = bucket->count;
if (*k < K) {
*k = K;
}

/**
  Append the {value, index} pairs of the elements equal to the pivot of
  bucket_data[0] to sort_data, directly behind the entries already counted in
  bucket_data[0].count.

  bucket_data[0].count is read once and used as the write offset (`base`).
  bucket_data[1].count is the atomic counter for the entries appended by this
  kernel. NOTE(review): it is assumed to be zero on entry; confirm the caller
  resets it.

  Nothing is appended when base is already greater than K - 1, or when there
  is no room left behind the existing entries.

  At most sort_data_size - base entries are stored. Slots are reserved with a
  saturating counter: atomicAdd hands out a slot, and a thread that receives a
  slot beyond the capacity undoes its increment and stops scanning. Unlike
  atomicInc, the counter therefore never wraps back to zero, no slot is handed
  out twice, and no write lands past the end of sort_data. When the kernel has
  finished, bucket_data[1].count equals
  min(number of elements equal to the pivot, sort_data_size - base).
*/
template <typename T, bool UseAbsVal, bool Largest>
__global__ void
init_val_idx_list_equal(const T *data, const int size, Bucket<T> *bucket_data,
                        ValIdx<T> *sort_data, const unsigned int sort_data_size,
                        unsigned int K) {
  const auto base = bucket_data[0].count;
  auto bucket = &bucket_data[1];

  // Enough candidates were collected already, or the buffer is full. The
  // second test also keeps the unsigned subtraction below from underflowing.
  if (base > K - 1 || base >= sort_data_size) {
    return;
  }

  const unsigned int capacity = sort_data_size - base;
  // bucket_data[0] is not written by this kernel, so the pivot is loop
  // invariant.
  const auto pivot = bucket_data[0].pivot;

  const auto thread = blockIdx.x * blockDim.x + threadIdx.x;
  const auto stride = blockDim.x * gridDim.x;

  for (unsigned int index = thread; index < size; index += stride) {
    T value = UseAbsVal ? abs(data[index]) : data[index];
    if (value == pivot) {
      const unsigned int slot = atomicAdd(&bucket->count, 1u);
      if (slot >= capacity) {
        // The buffer is full and stays full for the rest of this launch:
        // give the reservation back and stop scanning.
        atomicSub(&bucket->count, 1u);
        break;
      }
      sort_data[base + slot] = {value, index};
    }
  }
}

/**
  Fold the counter of bucket_data[1] into the counter of bucket_data[0], so
  that bucket_data[0].count holds the sum of both.

  In debug builds the summed count is asserted to be at least K.
*/
template <typename T>
__global__ void summary_valid_k(Bucket<T> *bucket_data, unsigned int K) {
  Bucket<T> &total = bucket_data[0];
  const Bucket<T> &extra = bucket_data[1];

  total.count += extra.count;
  assert(total.count >= K);
}

// Passed as sort_data_size to the candidate-collection kernels, used as the
// length of Buffer::sorted, and used as the block size of the single-block
// bitonic_sort launch.
const unsigned int MAX_K = 1024;

// NOTE(review): this hunk is shown without +/- markers, so removed and added
// lines appear interleaved below. Presumably the lines that mention valid_k
// and init_val_idx_list are the removed ones. Confirm against the repository
// before editing.
//
// Host-side driver: launches the kernels that fill sort_data with
// {value, index} candidates, then launches bitonic_sort on that buffer. All
// launches are issued without an explicit stream argument.
template <typename T, bool UseAbsVal, bool Largest>
__host__ void find_top_k_index(const T *data, const int size, Bucket<T> *bucket,
ValIdx<T> *sort_data, unsigned int K,
unsigned int *valid_k) {
ValIdx<T> *sort_data, unsigned int K) {
auto threads = NBLA_CUDA_NUM_THREADS;
auto blocks = NBLA_CUDA_GET_BLOCKS(size);

init_val_idx_list<T, UseAbsVal, Largest>
<<<blocks, threads>>>(data, size, bucket, sort_data, MAX_K, K, valid_k);
// `bucket` is handed to init_val_idx_list_equal and summary_valid_k as
// bucket_data; both index elements [0] and [1], so it must point at an array
// of at least two buckets.
init_val_idx_list_greater<T, UseAbsVal, Largest>
<<<blocks, threads>>>(data, size, bucket, sort_data, MAX_K, K);
init_val_idx_list_equal<T, UseAbsVal, Largest>
<<<blocks, threads>>>(data, size, bucket, sort_data, MAX_K, K);
summary_valid_k<<<1, 1>>>(bucket, K);
NBLA_CUDA_KERNEL_CHECK();

// The memory layout of ValIdxBitonic is exactly the same as ValIdx.
auto actual_sort_data =
reinterpret_cast<ValIdxBitonic<T, Largest> *>(sort_data);
bitonic_sort<<<1, MAX_K>>>(actual_sort_data, valid_k);
// The element count is passed as a pointer to bucket->count, the value
// summary_valid_k updated above.
bitonic_sort<<<1, MAX_K>>>(actual_sort_data, &bucket->count);

NBLA_CUDA_KERNEL_CHECK();
}
Expand All @@ -267,7 +297,6 @@ template <typename T> struct Buffer {
MinMax<T> minmax[CUDA_WARP_SIZE];
Bucket<T> bucket[CUDA_WARP_SIZE];
ValIdx<T> sorted[MAX_K];
unsigned int valid_k; // Used for transferring the valid K number
};

template <typename T, bool UseAbsVal, bool Largest>
Expand All @@ -276,8 +305,8 @@ __host__ void top_k_body(const T *data, const unsigned int size,
minmax<T, UseAbsVal, true>(data, size, &buffer->minmax[0]);
find_top_k_value<T, UseAbsVal, Largest>(data, size, &buffer->minmax[0],
&buffer->bucket[0], K);
find_top_k_index<T, UseAbsVal, Largest>(
data, size, &buffer->bucket[0], &buffer->sorted[0], K, &buffer->valid_k);
find_top_k_index<T, UseAbsVal, Largest>(data, size, &buffer->bucket[0],
&buffer->sorted[0], K);
}

template <typename T, bool UseAbsVal = false>
Expand Down

0 comments on commit 8e5520b

Please sign in to comment.