diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt index 0288c8a2f322..d94ba4b52001 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt @@ -118,7 +118,10 @@ foreach(_amd_chip ${_ukernel_supported_chips}) ROCM_ARCH ${_amd_chip} SRCS - "argmax_ukernel.c" + "iree_uk_amdgpu_argmax_f16i32.c" + "iree_uk_amdgpu_argmax_f16i64.c" + "iree_uk_amdgpu_argmax_f32i32.c" + "iree_uk_amdgpu_argmax_f32i64.c" ) endforeach() @@ -145,6 +148,10 @@ endforeach() # Generate a custom target with all file level dependencies and commands to # copy to our build tree locations. # Our GenDeviceLibs target depends on all of the defined device lib targets. +message(STATUS "_all_ukernel_bc_files=${_all_ukernel_bc_files}") +message(STATUS "_amd_ukernel_targets=${_amd_ukernel_targets}") +message(STATUS "_all_ukernel_bc_copy_commands=${_all_ukernel_bc_copy_commands}") + add_custom_command( OUTPUT ${_all_ukernel_bc_files} DEPENDS ${_amd_ukernel_targets} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c b/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c deleted file mode 100644 index 70021a215a02..000000000000 --- a/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" - -/* -Constraint/Tiling note: -For simplicity, we distribute all parallel dim across different workgroup, and -only use single subgroup/warp per workgroup. This constraint is also set during -tiling phase in KernelConfig. 
-*/ - -void __iree_uk_rocm_argmax_F32I32(const float *inputBuffer, - int64_t input_offset, int32_t *outputBuffer, - int64_t output_offset, - int64_t reductionSize) { - const int warpSize = __builtin_amdgcn_wavefrontsize(); - int32_t laneID = __builtin_amdgcn_workitem_id_x(); - // Set identity value to handle problem non divisible by subgroupSize. - float laneMax = - laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID]; - int32_t laneResult = laneID; - - // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical - // inaccuracy. - int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; - for (int i = 1; i < numBatches; ++i) { - int32_t idx = warpSize * i + laneID; - float newIn = - idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) - continue; - laneMax = __builtin_fmaxf(newIn, laneMax); - laneResult = newIn == laneMax ? idx : laneResult; - } - - // Final reduction with one subgroup - // NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on - // https://github.com/iree-org/iree/issues/16112. - float wgMax = laneMax; - for (int i = 1; i < warpSize; i *= 2) { - wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax); - } - // Check if there are multiple max value holders. - uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); - // if there is only one max value holder, write and exit. - if (__builtin_popcountll(laneHasMaxValmask) == 1) { - if (wgMax == laneMax) - outputBuffer[output_offset] = laneResult; - return; - } - // if there are multiple max value holder, find smallest index (argmax - // semantics). - int32_t indexVal = wgMax == laneMax ? 
laneResult : __INT32_MAX__; - laneResult = __ockl_wfred_min_i32(indexVal); - if (laneID == 0) - outputBuffer[output_offset] = laneResult; -} - -void __iree_uk_rocm_argmax_F32I64(const float *inputBuffer, - int64_t input_offset, int64_t *outputBuffer, - int64_t output_offset, - int64_t reductionSize) { - const int warpSize = __builtin_amdgcn_wavefrontsize(); - int32_t laneID = __builtin_amdgcn_workitem_id_x(); - // Set identity value to handle problem non divisible by subgroupSize. - float laneMax = - laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID]; - int64_t laneResult = laneID; - - // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical - // inaccuracy. - int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; - for (int i = 1; i < numBatches; ++i) { - int32_t idx = warpSize * i + laneID; - float newIn = - idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) - continue; - laneMax = __builtin_fmaxf(newIn, laneMax); - laneResult = newIn == laneMax ? idx : laneResult; - } - - // Final reduction with one subgroup - // NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on - // https://github.com/iree-org/iree/issues/16112. - float wgMax = laneMax; - for (int i = 1; i < warpSize; i *= 2) { - wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax); - } - // Check if there are multiple max value holders. - uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); - // if there is only one max value holder, write and exit. - if (__builtin_popcountll(laneHasMaxValmask) == 1) { - if (wgMax == laneMax) - outputBuffer[output_offset] = laneResult; - return; - } - // if there are multiple max value holder, find smallest index (argmax - // semantics). - int64_t indexVal = wgMax == laneMax ? 
laneResult : INT64_MAX; - laneResult = __ockl_wfred_min_i64(indexVal); - if (laneID == 0) - outputBuffer[output_offset] = laneResult; -} - -void __iree_uk_rocm_argmax_F16I32(const _Float16 *inputBuffer, - int64_t input_offset, int32_t *outputBuffer, - int64_t output_offset, - int64_t reductionSize) { - const int warpSize = __builtin_amdgcn_wavefrontsize(); - _Float16 NEG_F16_MAX = (_Float16)(-65504.0f); - int32_t laneID = __builtin_amdgcn_workitem_id_x(); - // Set identity value to handle problem non divisible by subgroupSize. - _Float16 laneMax = laneID >= reductionSize - ? NEG_F16_MAX - : inputBuffer[input_offset + laneID]; - int32_t laneResult = laneID; - - int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; - for (int i = 1; i < numBatches; ++i) { - int32_t idx = warpSize * i + laneID; - _Float16 newIn = - idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) - continue; - laneMax = __builtin_fmaxf16(newIn, laneMax); - laneResult = newIn == laneMax ? idx : laneResult; - } - // Final reduction with one subgroup - _Float16 wgMax = __ockl_wfred_max_f16(laneMax); - // Check if there are multiple max value holders. - uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); - // if there is only one max value holder, write and exit. - if (__builtin_popcountll(laneHasMaxValmask) == 1) { - if (wgMax == laneMax) - outputBuffer[output_offset] = laneResult; - return; - } - - // if there are multiple max value holder, find smallest index (argmax - // semantics). - int32_t indexVal = wgMax == laneMax ? 
laneResult : __INT32_MAX__; - laneResult = __ockl_wfred_min_i32(indexVal); - if (laneID == 0) - outputBuffer[output_offset] = laneResult; -} - -void __iree_uk_rocm_argmax_F16I64(const _Float16 *inputBuffer, - int64_t input_offset, int64_t *outputBuffer, - int64_t output_offset, - int64_t reductionSize) { - const int warpSize = __builtin_amdgcn_wavefrontsize(); - _Float16 NEG_F16_MAX = (_Float16)(-65504.0f); - int32_t laneID = __builtin_amdgcn_workitem_id_x(); - // Set identity value to handle problem non divisible by subgroupSize. - _Float16 laneMax = laneID >= reductionSize - ? NEG_F16_MAX - : inputBuffer[input_offset + laneID]; - int64_t laneResult = laneID; - - int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; - for (int i = 1; i < numBatches; ++i) { - int32_t idx = warpSize * i + laneID; - _Float16 newIn = - idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) - continue; - laneMax = __builtin_fmaxf16(newIn, laneMax); - laneResult = newIn == laneMax ? idx : laneResult; - } - - // Final reduction with one subgroup - _Float16 wgMax = __ockl_wfred_max_f16(laneMax); - // Check if there are multiple max value holders. - uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); - // if there is only one max value holder, write and exit. - if (__builtin_popcountll(laneHasMaxValmask) == 1) { - if (wgMax == laneMax) - outputBuffer[output_offset] = laneResult; - return; - } - // if there are multiple max value holder, find smallest index (argmax - // semantics). - int64_t indexVal = wgMax == laneMax ? 
laneResult : INT64_MAX; - laneResult = __ockl_wfred_min_i64(indexVal); - if (laneID == 0) - outputBuffer[output_offset] = laneResult; -} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c new file mode 100644 index 000000000000..41fe50a6528d --- /dev/null +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c @@ -0,0 +1,49 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" + +void iree_uk_amdgpu_argmax_f16i32(const _Float16 *inputBuffer, + int64_t input_offset, int32_t *outputBuffer, + int64_t output_offset, + int64_t reductionSize) { + const int warpSize = __builtin_amdgcn_wavefrontsize(); + _Float16 NEG_F16_MAX = (_Float16)(-65504.0f); + int32_t laneID = __builtin_amdgcn_workitem_id_x(); + // Set identity value to handle problem non divisible by subgroupSize. + _Float16 laneMax = laneID >= reductionSize + ? NEG_F16_MAX + : inputBuffer[input_offset + laneID]; + int32_t laneResult = laneID; + + int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; + for (int i = 1; i < numBatches; ++i) { + int32_t idx = warpSize * i + laneID; + _Float16 newIn = + idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; + if (newIn == laneMax) + continue; + laneMax = __builtin_fmaxf16(newIn, laneMax); + laneResult = newIn == laneMax ? idx : laneResult; + } + // Final reduction with one subgroup + _Float16 wgMax = __ockl_wfred_max_f16(laneMax); + // Check if there are multiple max value holders. + uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); + // if there is only one max value holder, write and exit. 
+ if (__builtin_popcountll(laneHasMaxValmask) == 1) { + if (wgMax == laneMax) + outputBuffer[output_offset] = laneResult; + return; + } + + // if there are multiple max value holder, find smallest index (argmax + // semantics). + int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__; + laneResult = __ockl_wfred_min_i32(indexVal); + if (laneID == 0) + outputBuffer[output_offset] = laneResult; +} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c new file mode 100644 index 000000000000..823fc3a4f296 --- /dev/null +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c @@ -0,0 +1,49 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" + +void iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer, + int64_t input_offset, int64_t *outputBuffer, + int64_t output_offset, + int64_t reductionSize) { + const int warpSize = __builtin_amdgcn_wavefrontsize(); + _Float16 NEG_F16_MAX = (_Float16)(-65504.0f); + int32_t laneID = __builtin_amdgcn_workitem_id_x(); + // Set identity value to handle problem non divisible by subgroupSize. + _Float16 laneMax = laneID >= reductionSize + ? NEG_F16_MAX + : inputBuffer[input_offset + laneID]; + int64_t laneResult = laneID; + + int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; + for (int i = 1; i < numBatches; ++i) { + int32_t idx = warpSize * i + laneID; + _Float16 newIn = + idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; + if (newIn == laneMax) + continue; + laneMax = __builtin_fmaxf16(newIn, laneMax); + laneResult = newIn == laneMax ? 
idx : laneResult; + } + + // Final reduction with one subgroup + _Float16 wgMax = __ockl_wfred_max_f16(laneMax); + // Check if there are multiple max value holders. + uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); + // if there is only one max value holder, write and exit. + if (__builtin_popcountll(laneHasMaxValmask) == 1) { + if (wgMax == laneMax) + outputBuffer[output_offset] = laneResult; + return; + } + // if there are multiple max value holder, find smallest index (argmax + // semantics). + int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX; + laneResult = __ockl_wfred_min_i64(indexVal); + if (laneID == 0) + outputBuffer[output_offset] = laneResult; +} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c new file mode 100644 index 000000000000..41aad8ba05c5 --- /dev/null +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c @@ -0,0 +1,54 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" + +void iree_uk_amdgpu_argmax_f32i32(const float *inputBuffer, + int64_t input_offset, int32_t *outputBuffer, + int64_t output_offset, + int64_t reductionSize) { + const int warpSize = __builtin_amdgcn_wavefrontsize(); + int32_t laneID = __builtin_amdgcn_workitem_id_x(); + // Set identity value to handle problem non divisible by subgroupSize. + float laneMax = + laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID]; + int32_t laneResult = laneID; + + // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical + // inaccuracy. 
+ int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; + for (int i = 1; i < numBatches; ++i) { + int32_t idx = warpSize * i + laneID; + float newIn = + idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; + if (newIn == laneMax) + continue; + laneMax = __builtin_fmaxf(newIn, laneMax); + laneResult = newIn == laneMax ? idx : laneResult; + } + + // Final reduction with one subgroup + // NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on + // https://github.com/iree-org/iree/issues/16112. + float wgMax = laneMax; + for (int i = 1; i < warpSize; i *= 2) { + wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax); + } + // Check if there are multiple max value holders. + uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); + // if there is only one max value holder, write and exit. + if (__builtin_popcountll(laneHasMaxValmask) == 1) { + if (wgMax == laneMax) + outputBuffer[output_offset] = laneResult; + return; + } + // if there are multiple max value holder, find smallest index (argmax + // semantics). + int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__; + laneResult = __ockl_wfred_min_i32(indexVal); + if (laneID == 0) + outputBuffer[output_offset] = laneResult; +} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c new file mode 100644 index 000000000000..5899322d7407 --- /dev/null +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c @@ -0,0 +1,54 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" + +void iree_uk_amdgpu_argmax_f32i64(const float *inputBuffer, + int64_t input_offset, int64_t *outputBuffer, + int64_t output_offset, + int64_t reductionSize) { + const int warpSize = __builtin_amdgcn_wavefrontsize(); + int32_t laneID = __builtin_amdgcn_workitem_id_x(); + // Set identity value to handle problem non divisible by subgroupSize. + float laneMax = + laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID]; + int64_t laneResult = laneID; + + // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical + // inaccuracy. + int32_t numBatches = (reductionSize + warpSize - 1) / warpSize; + for (int i = 1; i < numBatches; ++i) { + int32_t idx = warpSize * i + laneID; + float newIn = + idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; + if (newIn == laneMax) + continue; + laneMax = __builtin_fmaxf(newIn, laneMax); + laneResult = newIn == laneMax ? idx : laneResult; + } + + // Final reduction with one subgroup + // NOTE: __ockl_wfred_max_f32 has correctness issue on gfx1100 documented on + // https://github.com/iree-org/iree/issues/16112. + float wgMax = laneMax; + for (int i = 1; i < warpSize; i *= 2) { + wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax); + } + // Check if there are multiple max value holders. + uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax); + // if there is only one max value holder, write and exit. + if (__builtin_popcountll(laneHasMaxValmask) == 1) { + if (wgMax == laneMax) + outputBuffer[output_offset] = laneResult; + return; + } + // if there are multiple max value holder, find smallest index (argmax + // semantics). + int64_t indexVal = wgMax == laneMax ? 
laneResult : INT64_MAX; + laneResult = __ockl_wfred_min_i64(indexVal); + if (laneID == 0) + outputBuffer[output_offset] = laneResult; +} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/test/argmax_linking.mlir b/compiler/plugins/target/ROCM/builtins/ukernel/test/argmax_linking.mlir index 5b4b416e5059..a848c2904744 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/test/argmax_linking.mlir +++ b/compiler/plugins/target/ROCM/builtins/ukernel/test/argmax_linking.mlir @@ -2,8 +2,8 @@ // We want to check that uKernel is indeed generated from e2e workflow. -// CHECK: llvm.func @__iree_uk_rocm_argmax_F32I64 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F32I64 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f32i64 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f32i64 func.func @argmax_1d_f32i64(%arg0: tensor<1x?xf32>) -> tensor<1x1xi64> { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 @@ -29,8 +29,8 @@ func.func @argmax_1d_f32i64(%arg0: tensor<1x?xf32>) -> tensor<1x1xi64> { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F16I64 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F16I64 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f16i64 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f16i64 func.func @argmax_1d_f16i64(%arg0: tensor<1x?xf16>) -> tensor<1x1xi64> { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 @@ -57,8 +57,8 @@ func.func @argmax_1d_f16i64(%arg0: tensor<1x?xf16>) -> tensor<1x1xi64> { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F32I64 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F32I64 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f32i64 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f32i64 func.func @argmax_2d_f32i64(%arg0: tensor<16x?xf32>) -> tensor<16x1xi64> { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 @@ -84,8 +84,8 @@ func.func @argmax_2d_f32i64(%arg0: tensor<16x?xf32>) -> tensor<16x1xi64> { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F32I32 -// CHECK: llvm.call 
@__iree_uk_rocm_argmax_F32I32 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f32i32 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f32i32 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @argmax_3d_dyn_f32i32(%arg0: tensor) -> tensor { @@ -115,8 +115,8 @@ func.func @argmax_3d_dyn_f32i32(%arg0: tensor) -> tensor { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F32I64 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F32I64 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f32i64 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f32i64 func.func @argmax_3d_dyn_f32i64(%arg0: tensor) -> tensor { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 @@ -144,8 +144,8 @@ func.func @argmax_3d_dyn_f32i64(%arg0: tensor) -> tensor { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F16I32 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F16I32 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f16i32 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f16i32 func.func @argmax_3d_dyn_f16i32(%arg0: tensor) -> tensor { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 @@ -173,8 +173,8 @@ func.func @argmax_3d_dyn_f16i32(%arg0: tensor) -> tensor { // ----- -// CHECK: llvm.func @__iree_uk_rocm_argmax_F16I64 -// CHECK: llvm.call @__iree_uk_rocm_argmax_F16I64 +// CHECK: llvm.func @iree_uk_amdgpu_argmax_f16i64 +// CHECK: llvm.call @iree_uk_amdgpu_argmax_f16i64 func.func @argmax_3d_dyn_f16i64(%arg0: tensor) -> tensor { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp index 01a5935e1410..f76cdd1c6ccf 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp @@ -39,7 +39,7 @@ getFnNameAndDefAttrs(const char *ukernelName, std::string &typeSuffixID, FnNameAndDefAttrs result; if 
(isROCMBackend(targetAttr)) { result.name = - std::string("__iree_uk_rocm_") + ukernelName + "_" + typeSuffixID; + std::string("iree_uk_amdgpu_") + ukernelName + "_" + typeSuffixID; result.defAttrs.emplace_back(rewriter.getStringAttr("vm.import.module"), rewriter.getStringAttr("rocm")); } @@ -95,16 +95,12 @@ matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) { return failure(); } - std::string typeSuffixID = ""; - if (inputElemType.isF16() && indexElemType.isInteger(32)) { - typeSuffixID = "F16I32"; - } else if (inputElemType.isF16() && indexElemType.isInteger(64)) { - typeSuffixID = "F16I64"; - } else if (inputElemType.isF32() && indexElemType.isInteger(32)) { - typeSuffixID = "F32I32"; - } else if (inputElemType.isF32() && indexElemType.isInteger(64)) { - typeSuffixID = "F32I64"; - } else { + std::string typeSuffixID; + llvm::raw_string_ostream(typeSuffixID) << inputElemType << indexElemType; + // TODO(bjacob): this check won't be needed once this code is updated to + // look up the table of contents of embedded bitcode files, one per symbol. 
+ if (!(typeSuffixID == "f16i32" || typeSuffixID == "f16i64" || + typeSuffixID == "f32i32" || typeSuffixID == "f32i64")) { return rewriter.notifyMatchFailure( op, "unsupported combination of element types"); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir index d36d73f5ce03..cc71c379959f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir @@ -28,7 +28,7 @@ func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes // CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0 // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]] -// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "__iree_uk_rocm_argmax_F32I64" +// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64" // CHECK-SAME: ins(%[[ARG0]] : // CHECK-SAME: outs(%[[FILL]] : // CHECK: return %[[MICRO_KERNEL]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ukernel_pipeline_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ukernel_pipeline_transform.mlir index 7c9cfbc5dab9..af8ecf196115 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ukernel_pipeline_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ukernel_pipeline_transform.mlir @@ -43,7 +43,7 @@ func.func @argmax_1d_f16i64() attributes {hal.executable.target = #executable_ta // CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @argmax_1d_f16i64() // CHECK-SAME: translation_info = #[[$TRANSLATION]] -// CHECK: iree_codegen.ukernel.generic "__iree_uk_rocm_argmax_F16I64" +// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f16i64" // ----- @@ -92,7 +92,7 @@ func.func @argmax_2d_f32i64() attributes 
{hal.executable.target = #executable_ta // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: %[[SUBVIEW:.*]] = memref.subview{{.*}} memref<16x?xf32 // CHECK-SAME: to memref<1x?xf32 -// CHECK: iree_codegen.ukernel.generic "__iree_uk_rocm_argmax_F32I64" ins(%[[SUBVIEW]] +// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]] // -----