From ba9ea8590cd01a3cfaf058a0000da47c052a1a7a Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Fri, 2 Aug 2024 12:29:26 -0700
Subject: [PATCH] [LLVMGPU] Add im2col pipeline for convolution codegen
 (#18086)

This PR adds the remaining passes needed for the IGEMM pipeline using the
im2col op. It adds the `ConvertConv2DToIm2ColOp` pass behind a new flag,
`--iree-codegen-llvmgpu-use-igemm`, and it inserts the im2col decomposition
pass before the vectorization passes. `--iree-codegen-llvmgpu-use-igemm`
defaults to false until the IGEMM pipeline is more robust and performant.

---------

Signed-off-by: Max Dawkins
---
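Note: a minimal sketch of the rewrite that ConvertConv2DToIm2ColOp performs,
using the shapes from the new test below; the %input, %filter, %acc, and
%empty names are illustrative, not taken from actual pass output.

// Before: a strided 2-D convolution (NHWC input, HWCF filter).
%conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%input, %filter : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>)
    outs(%acc : tensor<2x16x16x1280xf32>) -> tensor<2x16x16x1280xf32>

// After: gather the 3x3x1280 filter windows of the input into a
// (B, M, K) = (2, 256, 11520) matrix; the filter collapses to a
// (K, N) = (11520, 1280) matrix, and the convolution reduces to the
// batched GEMM that appears as the linalg.generic in the test.
%im2col = iree_linalg_ext.im2col
    strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
    m_offset = [0] k_offset = [0]
    batch_pos = [0] m_pos = [1, 2] k_pos = [3]
    ins(%input : tensor<2x34x34x1280xf16>)
    outs(%empty : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16>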
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |  8 ++
 .../test/ROCDL/pipeline_tile_and_fuse.mlir    | 90 +++++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f2a975965c4a..47d137e3d344 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -72,6 +72,11 @@ static llvm::cl::opt<int> clLLVMGPUSharedMemoryLimit(
                    "allocated for the given target"),
     llvm::cl::init(163 * 1024));
 
+static llvm::cl::opt<bool>
+    clLLVMGPUUseIgemm("iree-codegen-llvmgpu-use-igemm",
+                      llvm::cl::desc("Enable implicit gemm for convolutions."),
+                      llvm::cl::init(false));
+
 llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                               const LLVMGPUPipelineOptions &options) {
   StringRef reorderStr = "";
@@ -241,6 +246,7 @@ static void tileAndBufferize(OpPassManager &funcPassManager) {
 
 static void addGPUVectorizationPasses(OpPassManager &funcPassManager) {
   funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
+  funcPassManager.addPass(IREE::LinalgExt::createDecomposeIm2colPass());
   // Vectorize.
   GenericVectorizationPassOptions options;
   options.vectorizePadding = true;
@@ -1043,6 +1049,8 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
     OpPassManager &modulePassManager) {
   {
     FunctionLikeNest funcPassManager(modulePassManager);
+    funcPassManager.addPredicatedPass(
+        clLLVMGPUUseIgemm, IREE::LinalgExt::createConvertConv2DToIm2ColOpPass);
     funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
     addCommonTargetExecutablePreprocessingPasses(funcPassManager);
     addEncodingToNopPasses(funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 95463a872aa2..165169791229 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -126,3 +126,93 @@ hal.executable public @main {
 // CHECK:     scf.yield %[[MM]]
 // CHECK:   %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
 // CHECK:   vector.transfer_write %[[LOOP_T]], %[[B2]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer, ReadOnly>,
+    #hal.descriptor_set.binding<1, storage_buffer, ReadOnly>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 64, 0], reduction = [0, 0, 0, 2], subgroup = [1, 2, 2], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}>
+hal.executable private @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @conv_igemm_im2col ordinal(0) layout(#pipeline_layout)
+      attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @conv_igemm_im2col() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64>} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x16x16x1280xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>> -> tensor<2x34x34x1280xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>> -> tensor<3x3x1280x1280xf16>
+        %5 = tensor.empty() : tensor<2x16x16x1280xf32>
+        %6 = tensor.empty() : tensor<2x256x11520xf16>
+        %7 = iree_linalg_ext.im2col
+            strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [0] k_offset = [0]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%3 : tensor<2x34x34x1280xf16>)
+            outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16>
+        %collapsed = tensor.collapse_shape %4 [[0, 1, 2], [3]] : tensor<3x3x1280x1280xf16> into tensor<11520x1280xf16>
+        %collapsed_0 = tensor.collapse_shape %5 [[0], [1, 2], [3]] : tensor<2x16x16x1280xf32> into tensor<2x256x1280xf32>
+        %8 = linalg.fill ins(%cst : f32) outs(%collapsed_0 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32>
+        %9 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                             affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+            iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+            ins(%7, %collapsed : tensor<2x256x11520xf16>, tensor<11520x1280xf16>)
+            outs(%8 : tensor<2x256x1280xf32>) attrs = {lowering_config = #config} {
+        ^bb0(%in: f16, %in_1: f16, %out: f32):
+          %10 = arith.extf %in : f16 to f32
+          %11 = arith.extf %in_1 : f16 to f32
+          %12 = arith.mulf %10, %11 : f32
+          %13 = arith.addf %12, %out : f32
+          linalg.yield %13 : f32
+        } -> tensor<2x256x1280xf32>
+        %expanded = tensor.expand_shape %9 [[0], [1, 2], [3]] output_shape [2, 16, 16, 1280] : tensor<2x256x1280xf32> into tensor<2x16x16x1280xf32>
+        flow.dispatch.tensor.store %expanded, %2, offsets = [0, 0, 0, 0], sizes = [2, 16, 16, 1280], strides = [1, 1, 1, 1] : tensor<2x16x16x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x16x16x1280xf32>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @conv_igemm_im2col
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-DAG: memref.alloc() : memref<1x64x32xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<32x64xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[LHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
+// CHECK: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
+// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[RHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
+// CHECK: gpu.barrier
+// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
+// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
+// CHECK: scf.yield %[[MM]]
+// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
+// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]]
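
Shape check for the test above: with a 34x34 spatial input, a 3x3 kernel, and
stride 2, each output spatial dimension works out to

    H_out = W_out = floor((34 - 3) / 2) + 1 = 16
    M = 16 * 16      = 256    (im2col rows per batch)
    K = 3 * 3 * 1280 = 11520  (gathered elements per row)

which matches the tensor<2x256x11520xf16> im2col result and the 11520x1280
collapsed filter. Assuming the MFMA_F32_16x16x16_F16 intrinsic from the
lowering config, the scf.for bound of 720 in the CHECK lines is consistent
with K / 16 = 11520 / 16 = 720 intrinsic K-steps, advanced two at a time to
match the reduction tile size of 2.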