From 5ea0b2102ecf402c7ed8903bd176dfdf23847a81 Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Fri, 19 Jul 2024 08:59:38 -0700
Subject: [PATCH] [Codegen] Add interface tensor reshape foldings to
 TileAndDistribute (#17758)

This PR adds reshape-into-interface-tensor folding patterns to
TileAndDistributeToWorkgroups. If there are reshapes between interface
tensors and their users, TileAndDistributeToWorkgroups can fail, so these
patterns preprocess the input IR into a form that can be distributed. The
patterns can create duplicate interface binding ops, so a CSE pass is added
after every TileAndDistributeToWorkgroups invocation.

---------

Signed-off-by: Max Dawkins
---
 .../TileAndDistributeToWorkgroupsPass.cpp     |  9 +++
 .../tile_and_distribute_to_workgroups.mlir    | 65 +++++++++++++++++++
 .../iree/compiler/Codegen/LLVMCPU/Passes.cpp  |  1 +
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |  1 +
 .../iree/compiler/Codegen/SPIRV/Passes.cpp    |  1 +
 .../src/iree/compiler/Codegen/VMVX/Passes.cpp |  1 +
 6 files changed, 78 insertions(+)

diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index d9dee6a28e10..2ce82fa7d49a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -285,6 +285,15 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() {
   auto funcOp = getOperation();
 
+  {
+    RewritePatternSet patterns(context);
+    populateReshapeToInterfaceTensorPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      funcOp.emitOpError("reshape to interface tensor patterns failed");
+      return signalPassFailure();
+    }
+  }
+
   // TODO(MaheshRavishankar): The logic of lowering workgroup count
   // needs to be moved out of this pass. Once this is moved to
   // use scf.forall, this logic can be moved to the scf.forall
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 02c68865681f..d5738c8bc4b0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -2663,3 +2663,68 @@ hal.executable private @set_size_to_tilesize_when_divisible {
 // NO-LOOP:       %[[RESULT:.+]] = linalg.generic
 // NO-LOOP:         -> tensor<1x16x128xf16>
 // NO-LOOP:       flow.dispatch.tensor.store %[[RESULT]], %{{.+}}, offsets = [%[[IDX_Y]], 0, %[[OFFX]]]
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[32, 16, 0]]>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "system-elf-x86_64">
+#translation = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+hal.executable private @reshape_matmul_tensors {
+  hal.executable.variant public @system_elf_x86_64 target(#executable_target_system_elf_x86_64_) {
+    hal.executable.export public @reshape_matmul layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @reshape_matmul() attributes {translation_info = #translation} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:tensor<64x2x256xf32>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:tensor<256x512xf32>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:tensor<128x512xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 2, 256], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:tensor<64x2x256xf32>> -> tensor<64x2x256xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:tensor<256x512xf32>> -> tensor<256x512xf32>
+        %collapsed = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<64x2x256xf32> into tensor<128x256xf32>
+        %5 = tensor.empty() : tensor<128x512xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x512xf32>) -> tensor<128x512xf32>
+        %7 = linalg.matmul {lowering_config = #config}
+            ins(%collapsed, %4 : tensor<128x256xf32>, tensor<256x512xf32>) outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 512], strides = [1, 1]
+            : tensor<128x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x512xf32>>
+        return
+      }
+    }
+  }
+}
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK:      hal.executable.export public @reshape_matmul
+// CHECK-NEXT:   (%[[DEVICE:.+]]: !hal.device)
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:      %[[C32:.+]] = arith.constant 32 : index
+// CHECK:          hal.return %[[C32]], %[[C4]], %[[C1]]
+// CHECK:      func.func @reshape_matmul()
+// CHECK:        scf.for %[[IV0:.+]] =
+// CHECK:          scf.for %[[IV1:.+]] =
+// CHECK-DAG:        %[[LHS:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [%[[IV0]], 0], sizes = [32, 256]
+// CHECK-DAG:        %[[RHS:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [0, %[[IV1]]], sizes = [256, 16]
+// CHECK-DAG:        %[[INIT:.+]] = tensor.empty
+// CHECK-DAG:        %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT]] :
+// CHECK-DAG:        %[[GEMM:.+]] = linalg.matmul
+// CHECK-SAME:         outs(%[[FILL]] :
+// CHECK:            flow.dispatch.tensor.store %[[GEMM]]
+// CHECK-SAME:         offsets = [%[[IV0]], %[[IV1]]], sizes = [32, 16]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index c99c3d156e77..f7db82676104 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -106,6 +106,7 @@ static llvm::cl::opt<bool> clForceArmStreaming(
 
 static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
   funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
+  funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createConvertToDestinationPassingStylePass());
   funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
   funcPassManager.addPass(createCanonicalizerPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d722334795d0..153896714bb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -222,6 +222,7 @@ tileAndDistributeToWorkgroup(OpPassManager &funcPassManager,
   funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
       kNumMaxParallelDims,
       linalg::DistributionMethod::CyclicNumProcsEqNumIters));
+  funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createConvertToDestinationPassingStylePass(
       useWARForCooperativeMatrixCodegen));
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index fd2b021ac6ae..d669b1772306 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -112,6 +112,7 @@ static void addTileAndDistributeToWorkgroupsPasses(
   funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
       kNumMaxParallelDims,
       linalg::DistributionMethod::CyclicNumProcsEqNumIters));
+  funcPassManager.addPass(createCSEPass());
   if (useFuseTensorPadWithConsumerPass) {
     funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
   }
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index 48bc5dbe90eb..a18598328aa6 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -37,6 +37,7 @@ static llvm::cl::opt<bool> clEnableUKernelsDecomposeLinalgGeneric(
 
 static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
   funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
+  funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createConvertToDestinationPassingStylePass());
   funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
   funcPassManager.addPass(createCanonicalizerPass());
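
For readers unfamiliar with the folding, the sketch below illustrates the kind of
rewrite populateReshapeToInterfaceTensorPatterns performs on the @reshape_matmul
test above. It is an illustrative sketch only, reconstructed from the test's IR and
CHECK lines; the exact IR produced by the patterns (value names, op order) may differ.

Before folding, the LHS is loaded at its 3-D interface shape and then collapsed, and
a reshape between the interface tensor and its user is the situation that can make
TileAndDistributeToWorkgroups fail:

    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
        : !flow.dispatch.tensor<readonly:tensor<64x2x256xf32>>
    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 2, 256], strides = [1, 1, 1]
        : !flow.dispatch.tensor<readonly:tensor<64x2x256xf32>> -> tensor<64x2x256xf32>
    %collapsed = tensor.collapse_shape %3 [[0, 1], [2]]
        : tensor<64x2x256xf32> into tensor<128x256xf32>

After folding, the collapse is absorbed into the interface tensor, so the binding is
read directly at the 2-D shape and tile-and-distribute can slice the load per
workgroup (the CHECK lines above expect tiled loads with sizes = [32, 256]):

    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
        : !flow.dispatch.tensor<readonly:tensor<128x256xf32>>
    %collapsed = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1]
        : !flow.dispatch.tensor<readonly:tensor<128x256xf32>> -> tensor<128x256xf32>

Because rewrites like this can leave duplicate hal.interface.binding.subspan ops for
the same binding, each pipeline above now runs CSE immediately after
TileAndDistributeToWorkgroupsPass.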