Skip to content

Commit

Permalink
Split and rename AMDGPU ukernels (iree-org#19273)
Browse files Browse the repository at this point in the history
1. Change ukernels prefix from `__iree_uk_rocm` to `iree_uk_amdgpu`.
2. Change ukernels to lowercase.
3. Split ukernels into separate .c files, one .c file <-> one ukernel
function.

---------

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Nov 24, 2024
1 parent 5de0f06 commit 265570a
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 221 deletions.
9 changes: 8 additions & 1 deletion compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ foreach(_amd_chip ${_ukernel_supported_chips})
ROCM_ARCH
${_amd_chip}
SRCS
"argmax_ukernel.c"
"iree_uk_amdgpu_argmax_f16i32.c"
"iree_uk_amdgpu_argmax_f16i64.c"
"iree_uk_amdgpu_argmax_f32i32.c"
"iree_uk_amdgpu_argmax_f32i64.c"
)
endforeach()

Expand All @@ -145,6 +148,10 @@ endforeach()
# Generate a custom target with all file level dependencies and commands to
# copy to our build tree locations.
# Our GenDeviceLibs target depends on all of the defined device lib targets.
message(STATUS "_all_ukernel_bc_files=${_all_ukernel_bc_files}")
message(STATUS "_amd_ukernel_targets=${_amd_ukernel_targets}")
message(STATUS "_all_ukernel_bc_copy_commands=${_all_ukernel_bc_copy_commands}")

add_custom_command(
OUTPUT ${_all_ukernel_bc_files}
DEPENDS ${_amd_ukernel_targets}
Expand Down
192 changes: 0 additions & 192 deletions compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

void iree_uk_amdgpu_argmax_f16i32(const _Float16 *inputBuffer,
int64_t input_offset, int32_t *outputBuffer,
int64_t output_offset,
int64_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
int32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
_Float16 laneMax = laneID >= reductionSize
? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
int32_t laneResult = laneID;

int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
int32_t idx = warpSize * i + laneID;
_Float16 newIn =
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
continue;
laneMax = __builtin_fmaxf16(newIn, laneMax);
laneResult = newIn == laneMax ? idx : laneResult;
}
// Final reduction with one subgroup
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// if there is only one max value holder, write and exit.
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
if (wgMax == laneMax)
outputBuffer[output_offset] = laneResult;
return;
}

// if there are multiple max value holder, find smallest index (argmax
// semantics).
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
laneResult = __ockl_wfred_min_i32(indexVal);
if (laneID == 0)
outputBuffer[output_offset] = laneResult;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

void iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer,
int64_t input_offset, int64_t *outputBuffer,
int64_t output_offset,
int64_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
int32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
_Float16 laneMax = laneID >= reductionSize
? NEG_F16_MAX
: inputBuffer[input_offset + laneID];
int64_t laneResult = laneID;

int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
int32_t idx = warpSize * i + laneID;
_Float16 newIn =
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
continue;
laneMax = __builtin_fmaxf16(newIn, laneMax);
laneResult = newIn == laneMax ? idx : laneResult;
}

// Final reduction with one subgroup
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// if there is only one max value holder, write and exit.
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
if (wgMax == laneMax)
outputBuffer[output_offset] = laneResult;
return;
}
// if there are multiple max value holder, find smallest index (argmax
// semantics).
int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
laneResult = __ockl_wfred_min_i64(indexVal);
if (laneID == 0)
outputBuffer[output_offset] = laneResult;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

void iree_uk_amdgpu_argmax_f32i32(const float *inputBuffer,
int64_t input_offset, int32_t *outputBuffer,
int64_t output_offset,
int64_t reductionSize) {
const int warpSize = __builtin_amdgcn_wavefrontsize();
int32_t laneID = __builtin_amdgcn_workitem_id_x();
// Set identity value to handle problem non divisible by subgroupSize.
float laneMax =
laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID];
int32_t laneResult = laneID;

// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
// inaccuracy.
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
for (int i = 1; i < numBatches; ++i) {
int32_t idx = warpSize * i + laneID;
float newIn =
idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx];
if (newIn == laneMax)
continue;
laneMax = __builtin_fmaxf(newIn, laneMax);
laneResult = newIn == laneMax ? idx : laneResult;
}

// Final reduction with one subgroup
// NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on
// https://github.com/iree-org/iree/issues/16112.
float wgMax = laneMax;
for (int i = 1; i < warpSize; i *= 2) {
wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
}
// Check if there are multiple max value holders.
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// if there is only one max value holder, write and exit.
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
if (wgMax == laneMax)
outputBuffer[output_offset] = laneResult;
return;
}
// if there are multiple max value holder, find smallest index (argmax
// semantics).
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
laneResult = __ockl_wfred_min_i32(indexVal);
if (laneID == 0)
outputBuffer[output_offset] = laneResult;
}
Loading

0 comments on commit 265570a

Please sign in to comment.