From 4a225330f8f43386d9865a8a9e280ab6a1e43e3c Mon Sep 17 00:00:00 2001 From: Li-Wen Chang <120213201+liwenchangbdbz@users.noreply.github.com> Date: Mon, 25 Sep 2023 23:19:18 -0700 Subject: [PATCH] [Release] Official Release ByteIR 1.4.0 (#71) * [Sync] internal a6ef5f00...22d6dee6 * [AIT] Disabled hw info log, added error handling * [CAT] Adjusted layout support * [compiler/doc] Added codegen doc * [frontend/torch] Added demo code, added more fx pattern for llm, fixed einsum, updated to 23b72244b1e1eaa0511cece8535b32810c1d0d7a * [GPU] Added reduction codegen for PTX * [Mhlo] Fixed transpose movedown, Added canonicalizer for gather with iota * [Release] released 1.4.0 package * [Runtime] Supported non-splat value in FillOp, added dropout support for flashV2 * [Util] Fixed bugs --- compiler/doc/codegen.md | 245 +++++ compiler/include/byteir/Analysis/UseRange.h | 15 + .../byteir/Conversion/FuncToByre/FuncToByre.h | 6 + compiler/include/byteir/Conversion/Passes.td | 26 + .../byteir/Conversion/ToLinalg/ToLinalg.h | 3 +- .../include/byteir/Conversion/ToPTX/ToPTX.h | 3 +- .../include/byteir/Dialect/CMakeLists.txt | 2 + .../include/byteir/Dialect/Cat/IR/CatOps.td | 11 + .../include/byteir/Dialect/GPU/CMakeLists.txt | 3 + compiler/include/byteir/Dialect/GPU/Passes.h | 37 + compiler/include/byteir/Dialect/GPU/Passes.td | 36 + .../Dialect/GPU/Transforms/Transforms.h | 32 + .../include/byteir/Dialect/Linalg/Passes.td | 10 + .../TransformOps/LinalgExtTransformOps.td | 39 + .../Linalg/Transforms/LinalgCollapseLoops.h | 7 +- .../byteir/Dialect/Tensor/CMakeLists.txt | 3 + .../include/byteir/Dialect/Tensor/Passes.h | 31 + .../include/byteir/Dialect/Tensor/Passes.td | 35 + .../Transforms/TensorPadSpecialization.h | 30 + .../byteir/Dialect/Transform/Passes.td | 36 + .../Transform/Transforms/TransformInsertion.h | 10 + .../Dialect/mhlo/Transforms/CanonicalizeExt.h | 3 + .../mhlo/Transforms/ConvertOpToCustomCall.h | 4 +- .../mhlo/Transforms/GenericFusionCommon.h | 4 + 
.../byteir/Dialect/mhlo/Transforms/HloFuser.h | 6 + .../byteir/Dialect/mhlo/Util/CustomCallUtil.h | 8 + .../byteir/Pipelines/GPU/MappingForall.h | 53 + .../byteir/Pipelines/GPU/ReductionCodegen.h | 153 +++ .../byteir/Pipelines/InitAllPipelines.h | 4 + .../byteir/Transforms/MemoryPlanning.h | 7 +- compiler/include/byteir/Transforms/Passes.td | 2 +- compiler/lib/Analysis/UseRange.cpp | 7 +- compiler/lib/CAPI/CMakeLists.txt | 1 + compiler/lib/CAPI/Passes.cpp | 4 + .../lib/Conversion/FuncToByre/FuncToByre.cpp | 78 ++ .../lib/Conversion/HloToCat/FuseHloToCat.cpp | 45 +- .../HloToCat/FuseHloToCatPattern.td | 30 +- .../Conversion/MemrefToByre/MemrefToByre.cpp | 25 +- compiler/lib/Conversion/ToByre/ToByre.cpp | 15 +- .../ToLinalg/MemrefCopyToLinalg.cpp | 140 +-- .../lib/Conversion/ToPTX/CollectGPUKernel.cpp | 33 +- compiler/lib/Dialect/CMakeLists.txt | 1 + compiler/lib/Dialect/Cat/IR/CatDialect.cpp | 6 + compiler/lib/Dialect/GPU/CMakeLists.txt | 1 + .../lib/Dialect/GPU/Transforms/CMakeLists.txt | 19 + .../Transforms/ShmAllocaToWorkgroupArg.cpp | 86 ++ .../TransformOps/LinalgExtTransformOps.cpp | 132 +++ .../Linalg/Transforms/FuseElementwise.cpp | 25 +- .../Linalg/Transforms/LinalgCollapseLoops.cpp | 45 +- .../Dialect/Tensor/Transforms/CMakeLists.txt | 3 + .../Tensor/Transforms/CanonicalizeExt.cpp | 42 + .../Dialect/Tensor/Transforms/PassDetail.h | 40 + .../Transforms/TensorPadSpecialization.cpp | 242 +++++ .../Transforms/TransformInsertion.cpp | 101 ++ .../mhlo/Transforms/CanonicalizeExt.cpp | 55 + .../lib/Dialect/mhlo/Transforms/CatFusion.cpp | 16 +- .../mhlo/Transforms/ConvertOpToCustomCall.cpp | 91 +- .../Dialect/mhlo/Transforms/GenericFusion.cpp | 104 +- .../mhlo/Transforms/HloAggressiveFusion.cpp | 5 +- .../Dialect/mhlo/Transforms/HloMoveDown.cpp | 54 +- compiler/lib/Pipelines/BufferizeOpt.cpp | 2 + compiler/lib/Pipelines/ByreOpt.cpp | 5 +- compiler/lib/Pipelines/GPU/CMakeLists.txt | 4 + compiler/lib/Pipelines/GPU/GPUOpt.cpp | 39 +- 
compiler/lib/Pipelines/GPU/MappingForall.cpp | 148 +++ compiler/lib/Pipelines/GPU/NVVMCodegen.cpp | 4 + .../lib/Pipelines/GPU/ReductionCodegen.cpp | 942 ++++++++++++++++++ compiler/lib/Pipelines/HloOpt.cpp | 3 + compiler/lib/Pipelines/LinalgMemrefOpt.cpp | 4 +- compiler/lib/Pipelines/LinalgTensorOpt.cpp | 138 ++- compiler/lib/Transforms/Bufferize.cpp | 238 ++++- compiler/lib/Transforms/MemoryPlanning.cpp | 47 +- compiler/lib/Utils/Utils.cpp | 11 +- compiler/numerical/hlo/canonicalize_ext.mlir | 37 + compiler/numerical/hlo/hlo_move_down.mlir | 22 + compiler/python/ByteIRModules.cpp | 2 + compiler/python/byteir/compile.py | 20 +- .../byteir/dialects/cat/ir_processor.py | 30 +- compiler/python/byteir/tools/compiler.py | 4 +- .../test/Conversion/HloToCat/fused_ops.mlir | 66 +- .../ToByre/convertMemRefToByre.mlir | 22 +- .../transform-op-fold-unit-extent-dims.mlir | 6 +- .../transforms/ConvertOpToCustomCall.mlir | 19 + .../Dialect/Mhlo/transforms/hloMoveDown.mlir | 26 +- .../test/Dialect/Tensor/canonicalizeExt.mlir | 11 + compiler/test/Transforms/canonicalizeExt.mlir | 37 + compiler/test/Transforms/memoryPlanning.mlir | 12 +- compiler/tools/byteir-opt/CMakeLists.txt | 1 + compiler/tools/byteir-opt/byteir-opt.cpp | 4 + external/patches/AITemplate/logging.patch | 17 + .../src/Conversion/OFRewriteToCustomCall.cpp | 1 + .../test/of_rewrite_to_custom_call.mlir | 2 +- .../torch-frontend/examples/demo/README.md | 16 + .../torch-frontend/examples/demo/backend.py | 195 ++++ .../examples/demo/byteir_fusible_pattern.py | 194 ++++ .../examples/demo/compile_utils.py | 92 ++ .../torch-frontend/examples/demo/config.py | 35 + .../examples/demo/fx_match_utils.py | 40 + .../torch-frontend/examples/demo/main.py | 220 ++++ .../examples/demo/partitioners.py | 940 +++++++++++++++++ .../third_party/patches/einsum.patch | 633 +++++++----- .../Conversion/ConvertTorchToCustomCall.cpp | 4 + .../python/test/test_attn_rewrite.py | 18 + .../python/test/test_fx_utils.py | 18 + 
.../python/torch_frontend/__init__.py | 3 +- .../python/torch_frontend/fx_utils.py | 86 ++ .../include/brt/core/framework/op_accessor.h | 3 + .../cuda/providers/default/ait/ait.cc | 3 +- .../cuda/providers/default/codegen/ptx.cc | 60 +- .../default/flash_attn/flash_attn_bwd.cc | 9 +- .../default/flash_attn/flash_attn_fwd.cc | 19 +- .../default/flash_attn/kernels/flash_api.cu | 11 +- .../default/flash_attn/kernels/flash_api.h | 4 +- .../flash_attn/kernels/flash_bwd_kernel.h | 28 +- .../flash_attn/kernels/flash_fwd_kernel.h | 15 +- .../providers/default/tensor_generate/fill.cc | 45 +- runtime/lib/core/framework/op_accessor.cc | 37 + .../providers/default/kernel/fill_test.cc | 10 +- .../default/kernel/flash_attn_fwd_test.cc | 14 +- .../test/include/brt/test/common/cuda/util.h | 13 + runtime/test/test_files/fill_cuda.mlir | 4 +- runtime/test/test_files/flash_attn_fwd.mlir | 4 +- tests/numerical_test/execute.py | 5 +- tests/numerical_test/main.py | 14 +- .../mlir_tests/ops/bmm_rrr_permute_f16.mlir | 6 + .../mlir_tests/ops/concat2.mlir | 6 + .../torch_dynamo_e2e_testing/backend.py | 18 +- 127 files changed, 6472 insertions(+), 584 deletions(-) create mode 100644 compiler/doc/codegen.md create mode 100644 compiler/include/byteir/Dialect/GPU/CMakeLists.txt create mode 100644 compiler/include/byteir/Dialect/GPU/Passes.h create mode 100644 compiler/include/byteir/Dialect/GPU/Passes.td create mode 100644 compiler/include/byteir/Dialect/GPU/Transforms/Transforms.h create mode 100644 compiler/include/byteir/Dialect/Tensor/CMakeLists.txt create mode 100644 compiler/include/byteir/Dialect/Tensor/Passes.h create mode 100644 compiler/include/byteir/Dialect/Tensor/Passes.td create mode 100644 compiler/include/byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h create mode 100644 compiler/include/byteir/Pipelines/GPU/MappingForall.h create mode 100644 compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h create mode 100644 compiler/lib/Dialect/GPU/CMakeLists.txt create mode 
100644 compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt create mode 100644 compiler/lib/Dialect/GPU/Transforms/ShmAllocaToWorkgroupArg.cpp create mode 100644 compiler/lib/Dialect/Tensor/Transforms/PassDetail.h create mode 100644 compiler/lib/Dialect/Tensor/Transforms/TensorPadSpecialization.cpp create mode 100644 compiler/lib/Pipelines/GPU/MappingForall.cpp create mode 100644 compiler/lib/Pipelines/GPU/ReductionCodegen.cpp create mode 100644 external/patches/AITemplate/logging.patch create mode 100644 frontends/torch-frontend/examples/demo/README.md create mode 100644 frontends/torch-frontend/examples/demo/backend.py create mode 100644 frontends/torch-frontend/examples/demo/byteir_fusible_pattern.py create mode 100644 frontends/torch-frontend/examples/demo/compile_utils.py create mode 100644 frontends/torch-frontend/examples/demo/config.py create mode 100644 frontends/torch-frontend/examples/demo/fx_match_utils.py create mode 100644 frontends/torch-frontend/examples/demo/main.py create mode 100644 frontends/torch-frontend/examples/demo/partitioners.py create mode 100644 frontends/torch-frontend/torch-frontend/python/test/test_fx_utils.py create mode 100644 tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f16.mlir create mode 100644 tests/numerical_test/mlir_tests/ops/concat2.mlir diff --git a/compiler/doc/codegen.md b/compiler/doc/codegen.md new file mode 100644 index 000000000..9299a6900 --- /dev/null +++ b/compiler/doc/codegen.md @@ -0,0 +1,245 @@ +# Codegen pipeline + +## hlo-opt + +This pass pipeline is mainly used for clustering fusion groups on the mhlo dialect; each fusion group is expected to be fused into a single kernel in the later codegen pipeline and is outlined as an independent kernel function. + +- `ReductionFusionPass` reduction fusion in producer direction + +- `ElementFusionPass` elementwise/broadcast/collapse_shape/expand_shape/etc. 
producer-consumer bi-directional fusion + +- `FusionOutliningPass` fusion group outlining + +## linalg-tensor-opt + +### reduction codegen transformations + +``` + func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.convert %arg0 : (tensor<8192x50257xf16>) -> tensor<8192x50257xf32> + %2 = mhlo.reduce(%1 init: %0) across dimensions = [0] : (tensor<8192x50257xf32>, tensor) -> tensor<50257xf32> + reducer(%arg1: tensor, %arg2: tensor) { + %3 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %3 : tensor + } + return %2 : tensor<50257xf32> + } +``` + +This pass pipeline first convert outlined mhlo fusion group into linalg dialect and try to fuse linalg op with its producer/consumer. + +- `createLinalgElementwiseFusionExtPass` linalg fusion pass with our extension, see [linalg pass](linalg.md) for more details + +``` +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<50257xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<50257xf32>) -> tensor<50257xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<8192x50257xf16>) outs(%1 : tensor<50257xf32>) { + ^bb0(%in: f16, %out: f32): + %3 = arith.extf %in : f16 to f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<50257xf32> + return %2 : tensor<50257xf32> +} +``` + +[optional] Split grid-level reduction on `reduction` dimensions + +``` +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<50257xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<50257xf32>) -> tensor<50257xf32> + %expanded = tensor.expand_shape 
%arg0 [[0, 1], [2]] : tensor<8192x50257xf16> into tensor<32x256x50257xf16> + %2 = tensor.empty() : tensor<32x50257xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<32x50257xf32>) -> tensor<32x50257xf32> + %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded : tensor<32x256x50257xf16>) outs(%3 : tensor<32x50257xf32>) attrs = {__grid_reduction__} { + ^bb0(%in: f16, %out: f32): + %6 = arith.extf %in : f16 to f32 + %7 = arith.addf %out, %6 : f32 + linalg.yield %7 : f32 + } -> tensor<32x50257xf32> + %5 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["reduction", "parallel"]} ins(%4 : tensor<32x50257xf32>) outs(%1 : tensor<50257xf32>) attrs = {__grid_reduction__} { + ^bb0(%in: f32, %out: f32): + %6 = arith.addf %in, %out : f32 + linalg.yield %6 : f32 + } -> tensor<50257xf32> + return %5 : tensor<50257xf32> +} +``` + +- Tiling reduction on `parallel` dimension and mapping tiled reductions to thread blocks + +``` +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<50257xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<8192x50257xf16> into tensor<32x256x50257xf16> + %1 = tensor.empty() : tensor<32x50257xf32> + %2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { + %4 = affine.min #map(%arg2) + %5 = affine.apply #map1(%arg2) + %extracted_slice = tensor.extract_slice %expanded[%arg1, 0, %5] [1, 256, %4] [1, 1, 1] : tensor<32x256x50257xf16> to tensor<256x?xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor + %6 = linalg.fill ins(%cst : f32) outs(%extracted_slice_0 : tensor) -> tensor + %7 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : 
tensor<256x?xf16>) outs(%6 : tensor) { + ^bb0(%in: f16, %out: f32): + %8 = arith.extf %in : f16 to f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %7 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor into tensor<32x50257xf32> + } + } {mapping = [#gpu.block, #gpu.block]} + %3 = scf.forall (%arg1) in (1571) shared_outs(%arg2 = %0) -> (tensor<50257xf32>) { + // ... + } {mapping = [#gpu.block]} + return %3 : tensor<50257xf32> +} +``` + +- Block-level reduction codegen + +``` +%2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { + %4 = affine.min #map(%arg2) + %5 = affine.apply #map1(%arg2) + %extracted_slice = tensor.extract_slice %expanded[%arg1, 0, %5] [1, 256, %4] [1, 1, 1] : tensor<32x256x50257xf16> to tensor<256x?xf16> + %6 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<32xf32> + %7 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<16x32xf32> + %8 = scf.forall (%arg4, %arg5) in (16, 32) shared_outs(%arg6 = %7) -> (tensor<16x32xf32>) { + %17 = affine.min #map2(%arg4) + %18 = affine.min #map3(%arg4) + %19 = affine.apply #map4(%18, %17) + %20 = affine.min #map5(%arg5, %arg2) + %21 = affine.min #map6(%arg5, %arg2) + %22 = affine.apply #map4(%21, %20) + %23 = affine.apply #map7(%21, %20) + %extracted_slice_6 = tensor.extract_slice %extracted_slice[%17, %20] [%19, %22] [1, 1] : tensor<256x?xf16> to tensor + %padded = tensor.pad %extracted_slice_6 low[0, 0] high[0, %23] { + ^bb0(%arg7: index, %arg8: index): + tensor.yield %cst : f16 + } : tensor to tensor<16x1xf16> + %extracted_slice_7 = tensor.extract_slice %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<16x32xf32> to tensor + %collapsed = tensor.collapse_shape %padded [[0, 1]] : tensor<16x1xf16> into tensor<16xf16> + %24 = linalg.fill ins(%cst_0 : f32) outs(%extracted_slice_7 : tensor) -> tensor + %25 = linalg.generic {indexing_maps = [#map8, 
#map9], iterator_types = ["reduction"]} ins(%collapsed : tensor<16xf16>) outs(%24 : tensor) { + ^bb0(%in: f16, %out: f32): + %26 = arith.extf %in : f16 to f32 + %27 = arith.addf %out, %26 : f32 + linalg.yield %27 : f32 + } -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %25 into %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor into tensor<16x32xf32> + } + } {mapping = [#gpu.thread, #gpu.thread]} + %expanded_1 = tensor.expand_shape %8 [[0, 1], [2]] : tensor<16x32xf32> into tensor<8x2x32xf32> + %9 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<8x32xf32> + %10 = scf.forall (%arg4, %arg5) in (8, 32) shared_outs(%arg6 = %9) -> (tensor<8x32xf32>) { + // ... + } {mapping = [#gpu.thread, #gpu.thread]} + %expanded_2 = tensor.expand_shape %10 [[0, 1], [2]] : tensor<8x32xf32> into tensor<4x2x32xf32> + %11 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<4x32xf32> + %12 = scf.forall (%arg4, %arg5) in (4, 32) shared_outs(%arg6 = %11) -> (tensor<4x32xf32>) { + // ... + } {mapping = [#gpu.thread, #gpu.thread]} + %expanded_3 = tensor.expand_shape %12 [[0, 1], [2]] : tensor<4x32xf32> into tensor<2x2x32xf32> + %13 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<2x32xf32> + %14 = scf.forall (%arg4, %arg5) in (2, 32) shared_outs(%arg6 = %13) -> (tensor<2x32xf32>) { + // ... + } {mapping = [#gpu.thread, #gpu.thread]} + %15 = scf.forall (%arg4) in (32) shared_outs(%arg5 = %6) -> (tensor<32xf32>) { + // ... + } {mapping = [#gpu.thread]} + %extracted_slice_4 = tensor.extract_slice %15[0] [%4] [1] : tensor<32xf32> to tensor + %extracted_slice_5 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor + %16 = scf.forall (%arg4) in (512) shared_outs(%arg5 = %extracted_slice_5) -> (tensor) { + // ... 
+ } {mapping = [#gpu.linear]} + scf.forall.in_parallel { + tensor.parallel_insert_slice %16 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor into tensor<32x50257xf32> + } +} {mapping = [#gpu.block, #gpu.block]} +``` + +- Detensorize scalar linalg ops to arith ops and specialize `tensor.pad` + +``` +%2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { + %4 = affine.min #map(%arg2) + %5 = affine.apply #map1(%arg2) + %6 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<32xf32> + %7 = bufferization.alloc_tensor() {memory_space = #gpu.address_space} : tensor<16x32xf32> + %8 = scf.forall (%arg4, %arg5) in (16, 32) shared_outs(%arg6 = %7) -> (tensor<16x32xf32>) { + %17 = affine.min #map2(%arg5, %arg2) + %18 = affine.min #map3(%arg5, %arg2) + %19 = affine.apply #map4(%18, %17) + %20 = arith.cmpi ugt, %19, %c0 : index + %21 = scf.if %20 -> (f16) { + %84 = affine.apply #map5(%arg4) + %85 = affine.apply #map6(%arg2)[%17] + %extracted = tensor.extract %expanded[%arg1, %84, %85] : tensor<32x256x50257xf16> + scf.yield %extracted : f16 + } else { + scf.yield %cst : f16 + } + // ... + %78 = arith.extf %77 : f16 to f32 + %79 = arith.addf %75, %78 : f32 + %80 = arith.cmpi ugt, %19, %c0 : index + %81 = scf.if %80 -> (f16) { + %84 = affine.apply #map21(%arg4) + %85 = affine.apply #map6(%arg2)[%17] + %extracted = tensor.extract %expanded[%arg1, %84, %85] : tensor<32x256x50257xf16> + scf.yield %extracted : f16 + } else { + scf.yield %cst : f16 + } + %82 = arith.extf %81 : f16 to f32 + %83 = arith.addf %79, %82 : f32 + %extracted_slice_5 = tensor.extract_slice %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<16x32xf32> to tensor + %inserted = tensor.insert %83 into %extracted_slice_5[] : tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %inserted into %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor into tensor<16x32xf32> + } + } {mapping = [#gpu.thread, #gpu.thread]} + + // ... 
+ %extracted_slice = tensor.extract_slice %15[0] [%4] [1] : tensor<32xf32> to tensor + %extracted_slice_4 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor + %16 = scf.forall (%arg4) in (512) shared_outs(%arg5 = %extracted_slice_4) -> (tensor) { + %17 = affine.min #map22(%arg4)[%4] + %18 = affine.max #map23(%17) + %19 = affine.apply #map24(%arg4)[%4] + %extracted_slice_5 = tensor.extract_slice %extracted_slice[%19] [%18] [1] : tensor to tensor + %extracted_slice_6 = tensor.extract_slice %arg5[%19] [%18] [1] : tensor to tensor + %20 = linalg.copy {__byteir_gpu_tile_block_reduction_10} ins(%extracted_slice_5 : tensor) outs(%extracted_slice_6 : tensor) -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %20 into %arg5[%19] [%18] [1] : tensor into tensor + } + } {mapping = [#gpu.linear]} + scf.forall.in_parallel { + tensor.parallel_insert_slice %16 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor into tensor<32x50257xf32> + } +} {mapping = [#gpu.block, #gpu.block]} +``` + +- `structured.split_reduction` split reduction op along `reduction` dimension for increasing parallelism + +- `structured.tile_to_forall_op` tile reduction op along `parallel` dimensions to `forall` op and mapping to block/linear/thread + +- `structured.fuse_into_containing_op` fuse init and pad operands into `scf.forall` + +- `structured.annotate` attach any attribute to target ops, used to annotate reduction op and attach memory space to `alloc_tensor` + +- `structured.tile` tile reduction op along `reduction` dimension to sequential for loop + +- `structured.detensorize` used to inline the computation region of linalg ops whose operands have scalar tensor type + +- `LinalgCollapseLoopsPass` collapse consecutive `parallel` and `reduction` loops, this pass could work on both tensor and memref + +- `TensorPadSpecializationPass` specialize `tensor.extract` of pad op to conditional read diff --git a/compiler/include/byteir/Analysis/UseRange.h 
b/compiler/include/byteir/Analysis/UseRange.h index 07b5588ae..704afcf16 100644 --- a/compiler/include/byteir/Analysis/UseRange.h +++ b/compiler/include/byteir/Analysis/UseRange.h @@ -104,9 +104,24 @@ class UserangeAnalysis { using UsePosition = std::pair; using UsePositionList = std::vector; + using AllocsIterator = mlir::bufferization::BufferPlacementAllocs:: + AllocEntryList::const_iterator; + using AllocsIteratorRange = llvm::iterator_range; + UserangeAnalysis(Liveness *liveness) : liveness(liveness) {} UserangeAnalysis(mlir::Operation *op, Liveness *liveness, const mlir::bufferization::BufferPlacementAllocs &allocs, + const mlir::BufferViewFlowAnalysis &aliases) + : UserangeAnalysis(op, liveness, make_range(allocs.begin(), allocs.end()), + aliases) {} + UserangeAnalysis( + mlir::Operation *op, Liveness *liveness, + const mlir::bufferization::BufferPlacementAllocs::AllocEntryList &allocs, + const mlir::BufferViewFlowAnalysis &aliases) + : UserangeAnalysis(op, liveness, make_range(allocs.begin(), allocs.end()), + aliases) {} + UserangeAnalysis(mlir::Operation *op, Liveness *liveness, + AllocsIteratorRange &&allocs, const mlir::BufferViewFlowAnalysis &aliases); virtual ~UserangeAnalysis() {} diff --git a/compiler/include/byteir/Conversion/FuncToByre/FuncToByre.h b/compiler/include/byteir/Conversion/FuncToByre/FuncToByre.h index 230ea0e1a..92e29fbac 100644 --- a/compiler/include/byteir/Conversion/FuncToByre/FuncToByre.h +++ b/compiler/include/byteir/Conversion/FuncToByre/FuncToByre.h @@ -27,9 +27,15 @@ class ModuleOp; void populateFuncToByreTensorPattern(RewritePatternSet &patterns, bool appendArgTypes); +void populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns, + bool useBarePtrCallConv); + std::unique_ptr> createConvertFuncToByreTensorPass(bool appendArgTypes = false); +std::unique_ptr +createConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv = false); + } // namespace mlir #endif // BYTEIR_CONVERSION_FUNCTOBYRE_FUNCTOBYRE_H diff --git 
a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td index 79c6acbb8..e5ee1c148 100644 --- a/compiler/include/byteir/Conversion/Passes.td +++ b/compiler/include/byteir/Conversion/Passes.td @@ -253,6 +253,9 @@ def CollectGPUKernel : Pass<"collect-gpu-kernel", "ModuleOp"> { Option<"moduleName", "module-name", "std::string", /*default=*/"\"unified\"", "Optional name for GPUModule to put all gpu kernels">, + Option<"removeHost", "remove-host", "bool", + /*default=*/"true", + "Whether to remove host part">, ]; } @@ -349,6 +352,26 @@ def ConvertFuncToByreTensor : Pass<"func-to-byre-tensor", "ModuleOp"> { } +//===----------------------------------------------------------------------===// +// GPULaunchFuncToByre +//===----------------------------------------------------------------------===// + +def ConvertGPULaunchFuncToByre : Pass<"gpu-launch-func-to-byre"> { + let summary = "Convert gpu.launch_func op to byre compute op."; + let constructor = "mlir::createConvertGPULaunchFuncToByrePass()"; + let dependentDialects = [ + "mlir::byre::ByreDialect", + "mlir::gpu::GPUDialect" + ]; + + let options = [ + Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool", + /*default=*/"false", + "Replace memref arguments in GPU functions with bare pointers." 
+ "All memrefs must have static shape">, + ]; +} + //===----------------------------------------------------------------------===// // MemrefToByre //===----------------------------------------------------------------------===// @@ -398,6 +421,9 @@ def MemrefCopyToLinalgPass : Option<"attachAttr", "attach-attr", "std::string", /*default=*/"", "An optional unit attribute attaching on target functions: ">, + Option<"outlining", "outlining", "bool", + /*default=*/"true", + "Whether to outline the copy op to a new function">, ]; } diff --git a/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h b/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h index eaed2dbcc..8e64ce9d3 100644 --- a/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h +++ b/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h @@ -55,7 +55,8 @@ std::unique_ptr> createLinalgExtToLinalgPass(); std::unique_ptr> createMemrefCopyToLinalgPass(std::string anchorTag = "", - std::string attachAttr = ""); + std::string attachAttr = "", + bool outlining = true); } // namespace mlir diff --git a/compiler/include/byteir/Conversion/ToPTX/ToPTX.h b/compiler/include/byteir/Conversion/ToPTX/ToPTX.h index de932d857..c5185df6c 100644 --- a/compiler/include/byteir/Conversion/ToPTX/ToPTX.h +++ b/compiler/include/byteir/Conversion/ToPTX/ToPTX.h @@ -33,7 +33,8 @@ createGenPTXConfigPass(bool useBarePtrCallConv = false); // TODO move to general GPU std::unique_ptr> -createCollectGPUKernelPass(const std::string &name = "unified"); +createCollectGPUKernelPass(const std::string &name = "unified", + bool removeHost = true); } // namespace mlir diff --git a/compiler/include/byteir/Dialect/CMakeLists.txt b/compiler/include/byteir/Dialect/CMakeLists.txt index c7113baae..3e8627a7e 100644 --- a/compiler/include/byteir/Dialect/CMakeLists.txt +++ b/compiler/include/byteir/Dialect/CMakeLists.txt @@ -3,11 +3,13 @@ add_subdirectory(Affine) add_subdirectory(Byre) add_subdirectory(Cat) add_subdirectory(Ccl) +add_subdirectory(GPU) 
add_subdirectory(Lace) add_subdirectory(Linalg) add_subdirectory(MemRef) add_subdirectory(mhlo) add_subdirectory(SCF) add_subdirectory(Shape) +add_subdirectory(Tensor) add_subdirectory(Transform) add_subdirectory(Vector) diff --git a/compiler/include/byteir/Dialect/Cat/IR/CatOps.td b/compiler/include/byteir/Dialect/Cat/IR/CatOps.td index be953e7a6..c8ef07b19 100644 --- a/compiler/include/byteir/Dialect/Cat/IR/CatOps.td +++ b/compiler/include/byteir/Dialect/Cat/IR/CatOps.td @@ -281,6 +281,17 @@ def Cat_GemmRCRPermuteOp : Cat_Op<"gemm_rcr_permute", [Cat_CatOpInterface, Pure] let hasVerifier = 1; } +def Cat_GemmRRRPermuteOp : Cat_Op<"gemm_rrr_permute", [Cat_CatOpInterface, Pure]> { + let summary = "gemm_rrr + permute0213 operator, output layout is [m / t1, t1, t2, n / t2]"; + let arguments = (ins AnyTensor : $lhs, + AnyTensor : $rhs, + I64Attr : $t1, + I64Attr : $t2); + let results = (outs AnyTensor : $output); + + let hasVerifier = 1; +} + def Cat_LayerNormOp : Cat_Op<"layernorm", [Cat_CatOpInterface, Pure]> { let summary = "layernorm operator"; let arguments = (ins AnyTensor : $input, diff --git a/compiler/include/byteir/Dialect/GPU/CMakeLists.txt b/compiler/include/byteir/Dialect/GPU/CMakeLists.txt new file mode 100644 index 000000000..53b17ff2e --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name ByteIRGPU) +add_public_tablegen_target(ByteIRGPUPassIncGen) diff --git a/compiler/include/byteir/Dialect/GPU/Passes.h b/compiler/include/byteir/Dialect/GPU/Passes.h new file mode 100644 index 000000000..6a86b80b6 --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/Passes.h @@ -0,0 +1,37 @@ +//===- Passes.h ----------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_GPU_PASSES_H +#define BYTEIR_DIALECT_GPU_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace gpu { +class GPUFuncOp; +} // namespace gpu + +#define GEN_PASS_DECL +#include "byteir/Dialect/GPU/Passes.h.inc" + +/// Generate the code for registering transforms passes. +#define GEN_PASS_REGISTRATION +#include "byteir/Dialect/GPU/Passes.h.inc" + +} // namespace mlir + +#endif // BYTEIR_DIALECT_GPU_PASSES_H diff --git a/compiler/include/byteir/Dialect/GPU/Passes.td b/compiler/include/byteir/Dialect/GPU/Passes.td new file mode 100644 index 000000000..862df14a2 --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/Passes.td @@ -0,0 +1,36 @@ +//===- Passes.td - Transforms pass definition file -------*--- tablegen -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + + +#ifndef BYTEIR_DIALECT_GPU_PASSES +#define BYTEIR_DIALECT_GPU_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// ShmAllocaToWorkgroupArg +//===----------------------------------------------------------------------===// + +def ShmAllocaToWorkgroupArg : Pass<"shm-alloca-to-workgroup-arg", "gpu::GPUModuleOp"> { + let summary = "Hoist shared memory alloca in gpu kernel to workgroup argument"; + let dependentDialects = [ + "gpu::GPUDialect", + "memref::MemRefDialect" + ]; +} + +#endif // BYTEIR_DIALECT_GPU_PASSES diff --git a/compiler/include/byteir/Dialect/GPU/Transforms/Transforms.h b/compiler/include/byteir/Dialect/GPU/Transforms/Transforms.h new file mode 100644 index 000000000..042d045dd --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/Transforms/Transforms.h @@ -0,0 +1,32 @@ +//===- Transforms.h -------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H +#define BYTEIR_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" + +namespace mlir { +namespace gpu { + +// hoist shared memory alloca in gpu kernel to workgroup arg +void hoistShmAllocaToWorkgroup(gpu::GPUFuncOp func); + +} // namespace gpu +} // namespace mlir + +#endif // BYTEIR_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Linalg/Passes.td b/compiler/include/byteir/Dialect/Linalg/Passes.td index 81c411b53..5393ef410 100644 --- a/compiler/include/byteir/Dialect/Linalg/Passes.td +++ b/compiler/include/byteir/Dialect/Linalg/Passes.td @@ -166,6 +166,16 @@ def LinalgCollapseLoops : Pass<"linalg-collapse-loops", "func::FuncOp"> { "tensor::TensorDialect", "memref::MemRefDialect" ]; + + let options = [ + Option<"iteratorType", "iterator-type", "mlir::utils::IteratorType", + /*default=*/"mlir::utils::IteratorType::parallel", "iterator type", + [{llvm::cl::values( + clEnumValN(mlir::utils::IteratorType::parallel, "parallel", + "parallel iterator type"), + clEnumValN(mlir::utils::IteratorType::reduction, "reduction", + "reduction iterator type"))}]>, + ]; } //===----------------------------------------------------------------------===// diff --git a/compiler/include/byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.td b/compiler/include/byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.td index ac7d80126..dc918c661 100644 --- a/compiler/include/byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.td +++ b/compiler/include/byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.td @@ -72,6 +72,18 @@ def CollapseDimsOp : Op]> { + let description = [{ + Detensorize linalg ops. 
+ }]; + + let arguments = (ins PDL_Operation:$target); + + let assemblyFormat = "$target attr-dict"; +} + def FoldUnitExtentDimsOp : Op]> { @@ -245,4 +257,31 @@ def FuseOperandsOp : Op { + let description = [{ + insert_slice_to_copy_ext extension. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + + let builders = [ + OpBuilder<(ins "Value":$target)>, + ]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + + #endif // BYTEIR_DIALECT_LINALG_TRANSFORMOPS_LINALG_EXT_TRANSFORMOPS \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Linalg/Transforms/LinalgCollapseLoops.h b/compiler/include/byteir/Dialect/Linalg/Transforms/LinalgCollapseLoops.h index 556a43c41..b5ef16af4 100644 --- a/compiler/include/byteir/Dialect/Linalg/Transforms/LinalgCollapseLoops.h +++ b/compiler/include/byteir/Dialect/Linalg/Transforms/LinalgCollapseLoops.h @@ -18,6 +18,7 @@ #ifndef BYTEIR_DIALECT_LINALG_TRANSFORMS_LINALGCOLLAPSELOOPS_H #define BYTEIR_DIALECT_LINALG_TRANSFORMS_LINALGCOLLAPSELOOPS_H +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Pass/Pass.h" #include @@ -26,7 +27,11 @@ namespace func { class FuncOp; } // namespace func -std::unique_ptr> createLinalgCollapseLoops(); +#define GEN_PASS_DECL_LINALGCOLLAPSELOOPS +#include "byteir/Dialect/Linalg/Passes.h.inc" + +std::unique_ptr> createLinalgCollapseLoops( + utils::IteratorType iteratorType = utils::IteratorType::parallel); } // namespace mlir diff --git a/compiler/include/byteir/Dialect/Tensor/CMakeLists.txt b/compiler/include/byteir/Dialect/Tensor/CMakeLists.txt new file mode 100644 index 000000000..a4f8266f9 
--- /dev/null +++ b/compiler/include/byteir/Dialect/Tensor/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name ByteIRTensor) +add_public_tablegen_target(ByteIRTensorPassIncGen) diff --git a/compiler/include/byteir/Dialect/Tensor/Passes.h b/compiler/include/byteir/Dialect/Tensor/Passes.h new file mode 100644 index 000000000..be0c09f87 --- /dev/null +++ b/compiler/include/byteir/Dialect/Tensor/Passes.h @@ -0,0 +1,31 @@ +//===- Passes.h ---------------------------------------------------- C++ --===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_TENSOR_PASSES_H +#define BYTEIR_DIALECT_TENSOR_PASSES_H + +#include "byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h" + +namespace mlir { + +/// Generate the code for registering transforms passes. 
+#define GEN_PASS_REGISTRATION +#include "byteir/Dialect/Tensor/Passes.h.inc" + +} // namespace mlir + +#endif // BYTEIR_DIALECT_TENSOR_PASSES_H diff --git a/compiler/include/byteir/Dialect/Tensor/Passes.td b/compiler/include/byteir/Dialect/Tensor/Passes.td new file mode 100644 index 000000000..cdfa73c4d --- /dev/null +++ b/compiler/include/byteir/Dialect/Tensor/Passes.td @@ -0,0 +1,35 @@ +//===- Passes.td - Transforms pass definition file -------*--- tablegen -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_TENSOR_PASSES +#define BYTEIR_DIALECT_TENSOR_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// TensorPadSpecialization +//===----------------------------------------------------------------------===// + +def TensorPadSpecialization : Pass<"tensor-pad-specialization", ""> { + let summary = "Specialize tensor.pad op"; + let constructor = "mlir::createTensorPadSpecializationPass()"; + let dependentDialects = [ + "scf::SCFDialect", + ]; +} + +#endif // BYTEIR_DIALECT_TENSOR_PASSES diff --git a/compiler/include/byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h b/compiler/include/byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h new file mode 100644 index 000000000..72f38cd03 --- /dev/null +++ b/compiler/include/byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h @@ -0,0 +1,30 @@ +//===- TensorPadSpecialization.h ---------------------------------- C++ --===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_TENSOR_TRANSFORMS_TENSORPADSPECIALIZATION_H +#define BYTEIR_DIALECT_TENSOR_TRANSFORMS_TENSORPADSPECIALIZATION_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { + +std::unique_ptr createTensorPadSpecializationPass(); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_TENSOR_TRANSFORMS_TENSORPADSPECIALIZATION_H diff --git a/compiler/include/byteir/Dialect/Transform/Passes.td b/compiler/include/byteir/Dialect/Transform/Passes.td index 6e82c4b3f..49a471a71 100644 --- a/compiler/include/byteir/Dialect/Transform/Passes.td +++ b/compiler/include/byteir/Dialect/Transform/Passes.td @@ -34,6 +34,24 @@ def TransformDialectInterpreter : Pass<"transform-dialect-interpreter", "ModuleO ]; } +//===----------------------------------------------------------------------===// +// DetensorizeTransformInsertion +//===----------------------------------------------------------------------===// + +def DetensorizeTransformInsertion : Pass<"insert-detensorize-transform", "ModuleOp"> { + let summary = "Insert detensorize transform IR to functions."; + let constructor = "mlir::createDetensorizeTransformInsertionPass()"; + let options = [ + Option<"funcAnchorAttr", "func-anchor", "std::string", + /*default=*/"", + "An optional Unit attribute anchoring on target functions.">, + Option<"matchPrefix", "match-prefix", "std::string", + /*default=*/"\"__byteir_detensorize\"", + "An optional match prefix attribute on target ops.">, + ]; +} + + //===----------------------------------------------------------------------===// // FuseExtTransformInsertion //===----------------------------------------------------------------------===// @@ -60,4 +78,22 @@ def FuseExtTransformInsertion : Pass<"insert-fuse-ext-transform", "ModuleOp"> { ]; } +//===----------------------------------------------------------------------===// +// RewriteInDPSTransformInsertion 
+//===----------------------------------------------------------------------===// + +def RewriteInDPSTransformInsertion : Pass<"insert-rewrite-in-dps-transform", "ModuleOp"> { + let summary = "Insert rewrite in destination-passing-style transform IR to functions."; + let constructor = "mlir::createRewriteInDPSTransformInsertionPass()"; + let options = [ + Option<"funcAnchorAttr", "func-anchor", "std::string", + /*default=*/"", + "An optional Unit attribute anchoring on target functions.">, + Option<"matchPrefix", "match-prefix", "std::string", + /*default=*/"\"__byteir_rewrite_in_dps\"", + "An optional match prefix attribute on target ops.">, + ]; +} + + #endif // BYTEIR_DIALECT_TRANSFORM_PASSES diff --git a/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h b/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h index 931b6d0bc..5741663ac 100644 --- a/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h +++ b/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h @@ -37,12 +37,22 @@ struct TransformInsertionConfig { std::unique_ptr> createGenericTransformInsertionPass(const TransformInsertionConfig &config); +std::unique_ptr> +createDetensorizeTransformInsertionPass( + const std::string &funcAnchor = "", + const std::string &matchPrefix = "__byteir_detensorize"); + std::unique_ptr> createFuseExtTransformInsertionPass( const std::string &funcAnchor = "", const std::string &matchPrefix = "unknown", const std::string &tileSizeAttrName = "", const std::string &tileInterchangeAttrName = "", const bool keepIntermediates = false); + +std::unique_ptr> +createRewriteInDPSTransformInsertionPass( + const std::string &funcAnchor = "", + const std::string &matchPrefix = "__byteir_rewrite_in_dps"); } // namespace mlir #endif // BYTEIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINSERTION_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h 
b/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h index fa0a1b94c..78e3aa1bd 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h @@ -45,6 +45,7 @@ class ReshapeOp; class MulOp; class SliceOp; class ReverseOp; +class GatherOp; // Most of these will push back to upstream // So this file only includes patterns, not a pass. @@ -143,6 +144,8 @@ LogicalResult simplifyTransposeReshapeTranspose(mhlo::TransposeOp op, LogicalResult foldReverseWithConstant(mhlo::ReverseOp op, PatternRewriter &rewriter); +LogicalResult foldGatherWithInput(mhlo::GatherOp op, PatternRewriter &rewriter); + // populate canonicalizeExt patterns void populateCanonicalizeExtPatterns(RewritePatternSet &patterns, MLIRContext *context, diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h b/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h index e84497390..b36ebc426 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/ConvertOpToCustomCall.h @@ -1,4 +1,4 @@ -//===- ConvertRngToCustomCall.h -------------------------------*--- C++ -*-===// +//===- ConvertOpToCustomCall.h --------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. 
// Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,6 +29,8 @@ class ModuleOp; void populateRngPatternToCustomCall(RewritePatternSet &patterns); +void populateFlashFwdRewritePattern(RewritePatternSet &patterns); + std::unique_ptr> createConvertOpToCustomCallPass(llvm::StringRef anchor = ""); diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h b/compiler/include/byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h index be1a10a93..df10ecfdc 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h @@ -41,6 +41,7 @@ struct GenericFuserConfig { std::function fuse_trigger; std::function fuse_with; std::function valid_single_op; + std::function valid_fusion_pattern; }; //===----------------------------------------------------------------------===// @@ -115,6 +116,9 @@ class GenericFusionPass : public GenericFusionBase { for (auto it = plan.rbegin(); it != plan.rend(); ++it) { auto &pattern = *it; + if (fuse_config.valid_fusion_pattern && !fuse_config.valid_fusion_pattern(pattern)) + continue; + if (pattern.size() > 1) { applyMhloFusionPattern(pattern, fuse_config.fuse_attr); } else if (this->clusterSingleOp.getValue()) { diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h b/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h index f49315bbe..6a0e10cb7 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h @@ -46,6 +46,10 @@ constexpr StringRef getByteIRMatmulEpilogueFusionAttrName() { return "__byteir_matmul_epilogue_fusion__"; } +constexpr StringRef getByteIRReductionFusionAttrName() { + return "__byteir_reduction_fusion__"; +} + constexpr StringRef getByteIRTrivialFusionAttrName() { return "__byteir_trivial_fusion__"; } @@ -102,6 +106,8 @@ std::unique_ptr> createTrivialFusionPass(); std::unique_ptr> createHloAggressiveFusionPass(); 
+std::unique_ptr> createReductionFusionPass(); + } // namespace mlir #endif // BYTEIR_DIALECT_MHLO_TRANSFORMS_HLOFUSER_H diff --git a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h index af68cf38a..83a53e329 100644 --- a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h +++ b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h @@ -85,6 +85,14 @@ constexpr llvm::StringRef getRngUniformName() { return CUSTOM_CALL_NAME_PREFIX "rng_uniform"; } +constexpr llvm::StringRef getFlashAttnFwdName() { + return CUSTOM_CALL_NAME_PREFIX "flash_attn_fwd"; +} + +constexpr llvm::StringRef getFlashAttnBwdName() { + return CUSTOM_CALL_NAME_PREFIX "flash_attn_bwd"; +} + constexpr llvm::StringRef getDynamicPartitionName() { return TF_NAME_PREFIX "DynamicPartition"; } diff --git a/compiler/include/byteir/Pipelines/GPU/MappingForall.h b/compiler/include/byteir/Pipelines/GPU/MappingForall.h new file mode 100644 index 000000000..202cc1e40 --- /dev/null +++ b/compiler/include/byteir/Pipelines/GPU/MappingForall.h @@ -0,0 +1,53 @@ +//===- MappingForall.h ---------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H +#define BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H + +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +struct GPUMappingForallOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_split_grid_reduction")}; + // TODO: option for grid/block dims hint +}; + +void createGPUMappingForallTransform(OpPassManager &pm, + const GPUMappingForallOptions &options); + +inline void registerGPUMappingForallPipelines() { + PassPipelineRegistration( + "insert-gpu-mapping-forall-transform", + "Insert transformation IR to mapping forall to corresponding blocks and " + "threads", + createGPUMappingForallTransform); +} + +} // namespace mlir + +#endif // BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H diff --git a/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h b/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h new file mode 100644 index 000000000..7aea80d51 --- /dev/null +++ b/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h @@ -0,0 +1,153 @@ +//===- ReductionCodegen.h -----------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_PIPELINES_GPU_REDUCTION_CODEGEN_H +#define BYTEIR_PIPELINES_GPU_REDUCTION_CODEGEN_H + +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +struct GPUSplitGridReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_split_grid_reduction")}; + Option splitFactor{*this, "split-factor", + llvm::cl::desc("split factor"), + llvm::cl::init(32)}; +}; + +struct GPUTileGridReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_tile_grid_reduction")}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size"), + llvm::cl::init(32)}; + Option blockSize{*this, "block-size", llvm::cl::desc("block size"), + llvm::cl::init(256)}; + Option usingForall{*this, "using-forall", + llvm::cl::desc("using forall"), + llvm::cl::init(true)}; +}; + +struct GPUSplitBlockReductionOptions + : public 
PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_split_block_reduction")}; + Option splitFactor{*this, "split-factor", + llvm::cl::desc("split factor"), + llvm::cl::init(32)}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size"), + llvm::cl::init(32)}; +}; + +struct GPUTileBlockReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_tile_block_reduction")}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size"), + llvm::cl::init(32)}; + Option blockSize{*this, "block-size", llvm::cl::desc("block size"), + llvm::cl::init(256)}; + Option usingForall{*this, "using-forall", + llvm::cl::desc("using forall"), + llvm::cl::init(true)}; +}; + +struct GPUTileThreadReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_tile_thread_reduction")}; +}; + +void createGPUSplitGridReductionTransform( + OpPassManager &pm, const GPUSplitGridReductionOptions &options); +void createGPUTileGridReductionTransform( + OpPassManager &pm, const GPUTileGridReductionOptions &options); +void createGPUSplitBlockReductionTransform( + OpPassManager &pm, const GPUSplitBlockReductionOptions &options); +void 
createGPUTileBlockReductionTransform( + OpPassManager &pm, const GPUTileBlockReductionOptions &options); +void createGPUTileThreadReductionTransform( + OpPassManager &pm, const GPUTileThreadReductionOptions &options); + +inline void registerGPUReductionCodegenPipelines() { + PassPipelineRegistration( + "insert-gpu-split-grid-reduction-transform", + "Insert transformation IR to split linalg reduction op", + createGPUSplitGridReductionTransform); + + PassPipelineRegistration( + "insert-gpu-tile-grid-reduction-transform", + "Insert transformation IR to tile linalg reduction op", + createGPUTileGridReductionTransform); + + PassPipelineRegistration( + "insert-gpu-split-block-reduction-transform", + "Insert transformation IR to split linalg reduction op", + createGPUSplitBlockReductionTransform); + + PassPipelineRegistration( + "insert-gpu-tile-block-reduction-transform", + "Insert transformation IR to tile linalg reduction op", + createGPUTileBlockReductionTransform); + + PassPipelineRegistration( + "insert-gpu-tile-thread-reduction-transform", + "Insert transformation IR to tile linalg reduction op", + createGPUTileThreadReductionTransform); +} + +} // namespace mlir + +#endif // BYTEIR_PIPELINES_GPU_REDUCTION_CODEGEN_H diff --git a/compiler/include/byteir/Pipelines/InitAllPipelines.h b/compiler/include/byteir/Pipelines/InitAllPipelines.h index 0cdf1e7ce..2a653f898 100644 --- a/compiler/include/byteir/Pipelines/InitAllPipelines.h +++ b/compiler/include/byteir/Pipelines/InitAllPipelines.h @@ -35,7 +35,9 @@ #include "byteir/Pipelines/GPU/ElementwiseCodegen.h" #include "byteir/Pipelines/GPU/GPUOpt.h" #include "byteir/Pipelines/GPU/LinalgMemrefGPU.h" +#include "byteir/Pipelines/GPU/MappingForall.h" #include "byteir/Pipelines/GPU/NVVMCodegen.h" +#include "byteir/Pipelines/GPU/ReductionCodegen.h" #include "byteir/Pipelines/Host/Codegen.h" #include "byteir/Pipelines/Host/HostOpt.h" @@ -65,6 +67,8 @@ inline void registerAllByteIRGPUPipelines() { 
registerLinalgMemrefGPUPipeline(); registerMatmulEpilogueGPUPipeline(); registerGPUElementwiseCodegenPipelines(); + registerGPUReductionCodegenPipelines(); + registerGPUMappingForallPipelines(); } inline void registerAllByteIRHostPipelines() { diff --git a/compiler/include/byteir/Transforms/MemoryPlanning.h b/compiler/include/byteir/Transforms/MemoryPlanning.h index 74a83b88c..c1ad2a181 100644 --- a/compiler/include/byteir/Transforms/MemoryPlanning.h +++ b/compiler/include/byteir/Transforms/MemoryPlanning.h @@ -23,18 +23,19 @@ #include namespace mlir { +class FunctionOpInterface; class Value; namespace func { class FuncOp; } // namespace func -std::unique_ptr> createMemoryPlanningPass(); +std::unique_ptr> createMemoryPlanningPass(); /// couldReuseBuffer is a user provided callback which receives a Value as /// parameter and returns whether the allocation corresponding to the Value can /// be reused -std::unique_ptr> -createMemoryPlanningPass(size_t alignment, +std::unique_ptr> +createMemoryPlanningPass(size_t alignment, bool alloca, size_t memSpace, std::function couldReuseAllocation); } // namespace mlir diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td index 4c6a9cb1b..8ac2f3e7c 100644 --- a/compiler/include/byteir/Transforms/Passes.td +++ b/compiler/include/byteir/Transforms/Passes.td @@ -237,7 +237,7 @@ def LoopTag : Pass<"loop-tag", "func::FuncOp"> { //===----------------------------------------------------------------------===// // Memory planning //===----------------------------------------------------------------------===// -def MemoryPlanning: Pass<"memory-planning", "mlir::func::FuncOp"> { +def MemoryPlanning: InterfacePass<"memory-planning", "mlir::FunctionOpInterface"> { let summary = "Pass to perform static memory planning"; let constructor = "mlir::createMemoryPlanningPass()"; let dependentDialects = [ diff --git a/compiler/lib/Analysis/UseRange.cpp b/compiler/lib/Analysis/UseRange.cpp index 
f8a6e1076..88708c154 100644 --- a/compiler/lib/Analysis/UseRange.cpp +++ b/compiler/lib/Analysis/UseRange.cpp @@ -383,10 +383,9 @@ void UseInterval::mergeAndEraseContiguousIntervals( iter = interval.erase(std::next(iter), next); } -UserangeAnalysis::UserangeAnalysis( - Operation *op, byteir::Liveness *liveness, - const bufferization::BufferPlacementAllocs &allocs, - const BufferViewFlowAnalysis &aliases) +UserangeAnalysis::UserangeAnalysis(Operation *op, byteir::Liveness *liveness, + AllocsIteratorRange &&allocs, + const BufferViewFlowAnalysis &aliases) : liveness(liveness) { // Walk over all operations and map them to an ID. op->walk([&](Operation *operation) { diff --git a/compiler/lib/CAPI/CMakeLists.txt b/compiler/lib/CAPI/CMakeLists.txt index d8bc56463..1ae8a97f2 100644 --- a/compiler/lib/CAPI/CMakeLists.txt +++ b/compiler/lib/CAPI/CMakeLists.txt @@ -33,6 +33,7 @@ add_mlir_public_c_api_library(ByteIRCAPI # dialect specific passes ByteIRAffinePasses ByteIRByrePasses + ByteIRGPUPasses ByteIRLinalgPasses ByteIRMemRefPasses ByteIRMhloPasses diff --git a/compiler/lib/CAPI/Passes.cpp b/compiler/lib/CAPI/Passes.cpp index e686dab22..6707875c9 100644 --- a/compiler/lib/CAPI/Passes.cpp +++ b/compiler/lib/CAPI/Passes.cpp @@ -21,10 +21,12 @@ #include "byteir/Dialect/Ace/Passes.h" #include "byteir/Dialect/Affine/Passes.h" #include "byteir/Dialect/Byre/Passes.h" +#include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/Linalg/Passes.h" #include "byteir/Dialect/MemRef/Passes.h" #include "byteir/Dialect/SCF/Passes.h" #include "byteir/Dialect/Shape/Passes.h" +#include "byteir/Dialect/Tensor/Passes.h" #include "byteir/Dialect/Transform/Passes.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/InitAllPipelines.h" @@ -45,11 +47,13 @@ void byteirRegisterAllPasses() { registerByteIRAcePasses(); registerByteIRAffinePasses(); registerByteIRByrePasses(); + registerByteIRGPUPasses(); registerByteIRLinalgPasses(); registerByteIRMemRefPasses(); 
registerByteIRMhloPassesExt(); registerByteIRSCFPasses(); registerByteIRShapePasses(); + registerByteIRTensorPasses(); registerByteIRTransformPasses(); // pipelines diff --git a/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp b/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp index 45e87ca16..bd1dce952 100644 --- a/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp +++ b/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp @@ -19,7 +19,9 @@ #include "byteir/Dialect/Byre/ByreDialect.h" #include "byteir/Dialect/Byre/Common.h" #include "byteir/Utils/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" @@ -86,6 +88,53 @@ class ConvertCallOpToByreTensorPattern : public OpRewritePattern { bool appendArgTypes; }; +class ConvertGPULaunchFuncToByrePattern + : public OpRewritePattern { + +public: + ConvertGPULaunchFuncToByrePattern(MLIRContext *ctx, bool useBarePtrCallConv) + : OpRewritePattern(ctx), + useBarePtrCallConv(useBarePtrCallConv) {} + + LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp, + PatternRewriter &rewriter) const override { + auto computeOp = rewriter.create( + launchOp->getLoc(), TypeRange(), "PTXOp", launchOp.getKernelOperands(), + /*memEffects*/ ArrayAttr()); + + computeOp->setAttr( + rewriter.getStringAttr("kernel_name"), + rewriter.getStringAttr(launchOp.getKernelName().getValue())); + + auto grid = launchOp.getGridSizeOperandValues(); + int64_t gx = cast(grid.x.getDefiningOp()).value(); + int64_t gy = cast(grid.y.getDefiningOp()).value(); + int64_t gz = cast(grid.z.getDefiningOp()).value(); + computeOp->setAttr("GridSize.x", rewriter.getI32IntegerAttr(gx)); + computeOp->setAttr("GridSize.y", rewriter.getI32IntegerAttr(gy)); + computeOp->setAttr("GridSize.z", rewriter.getI32IntegerAttr(gz)); + + auto block = 
launchOp.getBlockSizeOperandValues(); + int64_t bx = cast(block.x.getDefiningOp()).value(); + int64_t by = cast(block.y.getDefiningOp()).value(); + int64_t bz = cast(block.z.getDefiningOp()).value(); + computeOp->setAttr("BlockSize.x", rewriter.getI32IntegerAttr(bx)); + computeOp->setAttr("BlockSize.y", rewriter.getI32IntegerAttr(by)); + computeOp->setAttr("BlockSize.z", rewriter.getI32IntegerAttr(bz)); + + if (useBarePtrCallConv) { + computeOp->setAttr(byre::getKernelCallConventionAttrName(), + rewriter.getStringAttr("bare_ptr")); + } + rewriter.eraseOp(launchOp); + + return success(); + } + +private: + bool useBarePtrCallConv; +}; + struct ConvertFuncToByreTensorPass : public ConvertFuncToByreTensorBase { public: @@ -104,6 +153,24 @@ struct ConvertFuncToByreTensorPass } } }; + +struct ConvertGPULaunchFuncToByrePass + : public ConvertGPULaunchFuncToByreBase { +public: + ConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv) + : ConvertGPULaunchFuncToByreBase() { + this->useBarePtrCallConv = useBarePtrCallConv; + } + void runOnOperation() override { + MLIRContext &ctx = getContext(); + RewritePatternSet patterns(&ctx); + populateGPULaunchFuncToByrePattern(patterns, useBarePtrCallConv); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; }; // namespace void mlir::populateFuncToByreTensorPattern(RewritePatternSet &patterns, @@ -112,7 +179,18 @@ void mlir::populateFuncToByreTensorPattern(RewritePatternSet &patterns, appendArgTypes); } +void mlir::populateGPULaunchFuncToByrePattern(RewritePatternSet &patterns, + bool useBarePtrCallConv) { + patterns.add(patterns.getContext(), + useBarePtrCallConv); +} + std::unique_ptr> mlir::createConvertFuncToByreTensorPass(bool appendArgTypes) { return std::make_unique(appendArgTypes); } + +std::unique_ptr +mlir::createConvertGPULaunchFuncToByrePass(bool useBarePtrCallConv) { + return std::make_unique(useBarePtrCallConv); +} \ No newline at end of file diff --git 
a/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp b/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp index e78709194..452075ce4 100644 --- a/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp +++ b/compiler/lib/Conversion/HloToCat/FuseHloToCat.cpp @@ -283,6 +283,48 @@ struct ConvertBmmReshapeTransposeToBmmReshape } }; +// bmm_rrr(x, broadcast_in_dim(y)) => reshape(gemm_rrr(reshape(x), y)) +struct ConvertBmmRRRBroadcastToReshapeGemmRRRReshape + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(cat::BMMRRROp op, + PatternRewriter &rewriter) const override { + auto bCastOp = op.getRhs().getDefiningOp(); + if (!bCastOp) { + return failure(); + } + auto lhsType = op.getLhs().getType().cast(); + auto rhsType = op.getRhs().getType().cast(); + if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) { + return failure(); + } + SmallVector broadcastDimensions; + getValuesFromDenseIntElementsAttr(bCastOp.getBroadcastDimensions(), + broadcastDimensions); + if (broadcastDimensions.size() != 2) { + return failure(); + } + if (broadcastDimensions[0] != 1 || broadcastDimensions[1] != 2) { + return failure(); + } + + RankedTensorType firstReshapeType = RankedTensorType::get( + {lhsType.getDimSize(0) * lhsType.getDimSize(1), lhsType.getDimSize(2)}, + lhsType.getElementType()); + RankedTensorType gemmType = RankedTensorType::get( + {firstReshapeType.getDimSize(0), rhsType.getDimSize(2)}, + lhsType.getElementType()); + auto firstReshape = rewriter.create( + op.getLoc(), firstReshapeType, op.getLhs()); + auto gemm = rewriter.create( + op.getLoc(), gemmType, firstReshape, bCastOp.getOperand()); + auto secondReshape = + rewriter.create(op.getLoc(), op.getType(), gemm); + rewriter.replaceOp(op, secondReshape); + return success(); + } +}; + struct FuseMhloToCatPass : public FuseMhloToCatBase { public: FuseMhloToCatPass() = default; @@ -317,7 +359,8 @@ void populateFuseMhloToCatPattern(RewritePatternSet &patterns) { 
ConvertBmmReshapeTransposeToBmmReshape, ConvertBmmReshapeTransposeToBmmReshape, ConvertBmmReshapeTransposeToBmmReshape, - ConvertBmmReshapeTransposeToBmmReshape + ConvertBmmReshapeTransposeToBmmReshape, + ConvertBmmRRRBroadcastToReshapeGemmRRRReshape >(patterns.getContext()); // clang-format on } diff --git a/compiler/lib/Conversion/HloToCat/FuseHloToCatPattern.td b/compiler/lib/Conversion/HloToCat/FuseHloToCatPattern.td index c984993e2..ba101c1af 100644 --- a/compiler/lib/Conversion/HloToCat/FuseHloToCatPattern.td +++ b/compiler/lib/Conversion/HloToCat/FuseHloToCatPattern.td @@ -30,6 +30,7 @@ def OneRank : Constraint().getRank() == 1"> def TwoRank : Constraint().getRank() == 2">, "two rank">; def ThreeRank : Constraint().getRank() == 3">, "three rank">; def FourRank : Constraint().getRank() == 4">, "four rank">; +def Permute10Check : Constraint()[0] == 1 && $0.getValues()[1] == 0">, "transpose <[1, 0]>">; def Permute021Check : Constraint()[0] == 0 && $0.getValues()[1] == 2 && $0.getValues()[2] == 1">, "bmm 3d permute check (for transpose before bmm)">; def Permute0213Check : Constraint()[0] == 0 && $0.getValues()[1] == 2 && $0.getValues()[2] == 1 && $0.getValues()[3] == 3">, "bmm 4d permute check (for transpose after bmm)">; def TransposeCheck : Constraint()[0] == 1 && $0.getValues()[1] == 0">, "matrix transpose check">; @@ -116,6 +117,26 @@ def MhloCatGemmRCRTransToCatGemmRCRPermutePattern (Cat_GemmRCRPermuteOp $lhs, $rhs, (getDim1Attr $reshape_out), (getDim2Attr $reshape_out)), [(TwoRank $lhs), (TwoRank $rhs), (FourRank $reshape_out), (GemmPermuteShapeCheck $reshape_out, $gemm_out), (Permute0213Check $permute)]>; +def MhloCatGemmRRRTransToCatGemmRRRPermutePattern + : Pat<(MHLO_TransposeOp + (MHLO_ReshapeOp : $reshape_out + (Cat_GemmRRROp : $gemm_out + $lhs, $rhs) + ), + $permute), + (Cat_GemmRRRPermuteOp $lhs, $rhs, (getDim1Attr $reshape_out), (getDim2Attr $reshape_out)), + [(TwoRank $lhs), (TwoRank $rhs), (FourRank $reshape_out), (GemmPermuteShapeCheck 
$reshape_out, $gemm_out), (Permute0213Check $permute)]>; + +def CatGemmRRRPermuteTransToCatGemmRCRPermutePattern + : Pat<(Cat_GemmRRRPermuteOp + $lhs, + (MHLO_TransposeOp $rhs, $permute), + $t1, + $t2 + ), + (Cat_GemmRCRPermuteOp $lhs, $rhs, $t1, $t2), + [(TwoSize $permute), (Permute10Check $permute)]>; + def LayoutFrom3DDotGeneralDimNums : NativeCodeCall<"GetLayoutFrom3DDotGeneralDimNums($0, &$_builder)">; def CheckRRRLayoutFrom3DDotGeneralDimNums @@ -140,6 +161,12 @@ def CheckCCRLayoutFrom3DDotGeneralDimNums CPred<"$0.getLhsContractingDimensions().size() == 1 && $0.getRhsContractingDimensions().size() == 1 && $0.getLhsContractingDimensions()[0] == 1 && $0.getRhsContractingDimensions()[0] == 2">, "is bmm ccr dimension">; +def CheckBMMPermuteShapeSplitOnBatch + : Constraint< + CPred<"$0.getType().cast().getShape()[0] * $0.getType().cast().getShape()[1] == $1.getType().cast().getShape()[0] && $0.getType().cast().getShape()[2] == $1.getType().cast().getShape()[1]">, + "bmm rrr Shape Split On Batch">; + + def MhloDotGeneralReshapeTransposeToBMMRRRPermutePattern : Pat<(MHLO_TransposeOp (MHLO_ReshapeOp : $reshape_out @@ -153,7 +180,8 @@ def MhloDotGeneralReshapeTransposeToBMMRRRPermutePattern (ThreeRank $rhs), (FourSize $permute), (Permute0213Check $permute), - (CheckRRRLayoutFrom3DDotGeneralDimNums $dimension_numbers) + (CheckRRRLayoutFrom3DDotGeneralDimNums $dimension_numbers), + (CheckBMMPermuteShapeSplitOnBatch $reshape_out, $lhs) ]>; def MhloDotGeneralReshapeTransposeToBMMRCRPermutePattern diff --git a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp index 49676e2b1..f759d2d0b 100644 --- a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp +++ b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp @@ -69,6 +69,27 @@ class ConvertViewOpToByrePattern : public OpConversionPattern { } }; +class ConvertSubViewOpToByrePattern + : public OpConversionPattern { +public: + 
ConvertSubViewOpToByrePattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(memref::SubViewOp op, memref::SubViewOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getType().getLayout().isIdentity()) + return failure(); + + if (!op.getSource().getType().getLayout().isIdentity()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + adaptor.getSource(), 0); + return success(); + } +}; + class ConvertMemrefCopyOpToByrePattern : public OpConversionPattern { public: @@ -174,8 +195,8 @@ void mlir::populateMemrefToByrePattern(RewritePatternSet &patterns) { patterns.add, - ConvertReshapeLikeOpToByrePattern>( - patterns.getContext()); + ConvertReshapeLikeOpToByrePattern, + ConvertSubViewOpToByrePattern>(patterns.getContext()); } std::unique_ptr> diff --git a/compiler/lib/Conversion/ToByre/ToByre.cpp b/compiler/lib/Conversion/ToByre/ToByre.cpp index 1c83a25b3..31e43449c 100644 --- a/compiler/lib/Conversion/ToByre/ToByre.cpp +++ b/compiler/lib/Conversion/ToByre/ToByre.cpp @@ -1074,10 +1074,11 @@ static bool isRewritablePrivateFunc(func::FuncOp func) { } // identify EntryPoint funciton -static void identifyEntryPointFuncAndCalls( - ModuleOp m, llvm::SmallVector &entries, - llvm::SmallVector &calls, - llvm::SmallVector &removeFuncs) { +static void +identifyEntryPointFuncAndCalls(ModuleOp m, + llvm::SmallVector &entries, + llvm::SmallVector &calls, + llvm::SetVector &removeFuncs) { // get first entry func llvm::SmallPtrSet callSet; @@ -1094,7 +1095,7 @@ static void identifyEntryPointFuncAndCalls( if (isRewritablePrivateFunc(calleeFuncOp) && !callSet.contains(callOp)) { calls.push_back(callOp); callSet.insert(callOp); - removeFuncs.push_back(calleeFuncOp); + removeFuncs.insert(calleeFuncOp); } } } @@ -1273,7 +1274,7 @@ void ConvertFuncAndCallToByrePass::runOnOperation() { MLIRContext &ctx = getContext(); llvm::SmallVector entryCollector; llvm::SmallVector callCollector; - 
llvm::SmallVector removeFuncCollector; + llvm::SetVector removeFuncCollector; identifyEntryPointFuncAndCalls(m, entryCollector, callCollector, removeFuncCollector); @@ -1330,7 +1331,7 @@ void ConvertFuncAndCallToByrePass::runOnOperation() { return signalPassFailure(); } - for (auto func : removeFuncCollector) { + for (auto func : removeFuncCollector.takeVector()) { func->erase(); } } diff --git a/compiler/lib/Conversion/ToLinalg/MemrefCopyToLinalg.cpp b/compiler/lib/Conversion/ToLinalg/MemrefCopyToLinalg.cpp index 8ceae81d9..2fe676adb 100644 --- a/compiler/lib/Conversion/ToLinalg/MemrefCopyToLinalg.cpp +++ b/compiler/lib/Conversion/ToLinalg/MemrefCopyToLinalg.cpp @@ -39,8 +39,9 @@ namespace { struct MemrefCopyOpToLinalg : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; MemrefCopyOpToLinalg(MLIRContext *ctx, std::string anchorTag, - std::string attachAttr) - : OpRewritePattern(ctx), anchorTag(anchorTag), attachAttr(attachAttr) {} + std::string attachAttr, bool outlining) + : OpRewritePattern(ctx), anchorTag(anchorTag), attachAttr(attachAttr), + outlining(outlining) {} LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override { @@ -56,84 +57,101 @@ struct MemrefCopyOpToLinalg : public OpRewritePattern { auto dstType = llvm::dyn_cast(dst.getType()); if (!srcType || !dstType) return failure(); - if (srcType.getLayout().isIdentity() && dstType.getLayout().isIdentity()) - return failure(); - SmallVector ops; - auto getViewSource = [&](Value value) { - while (auto viewOp = value.getDefiningOp()) { - ops.push_back(viewOp); - value = viewOp.getViewSource(); + if (outlining) { + if (srcType.getLayout().isIdentity() && dstType.getLayout().isIdentity()) + return failure(); + + SmallVector ops; + auto getViewSource = [&](Value value) { + while (auto viewOp = value.getDefiningOp()) { + ops.push_back(viewOp); + value = viewOp.getViewSource(); + } + return value; + }; + Value callSrc = getViewSource(src); + Value callDst = 
getViewSource(dst); + + auto symbolTableOp = SymbolTable::getNearestSymbolTable(copyOp); + SymbolTable symbolTable(symbolTableOp); + auto funcType = + rewriter.getFunctionType({callSrc.getType(), callDst.getType()}, {}); + + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(parentOp); + func::FuncOp funcOp = rewriter.create( + copyOp->getLoc(), "memref_copy_kernel", funcType); + symbolTable.insert(funcOp); + funcOp.setPrivate(); + + Block *entryBlock = funcOp.addEntryBlock(); + rewriter.setInsertionPointToStart(entryBlock); + IRMapping mapping; + mapping.map(ValueRange{callSrc, callDst}, entryBlock->getArguments()); + for (auto &&op : llvm::reverse(ops)) { + auto newOp = rewriter.clone(*op, mapping); + mapping.map(op, newOp); + } + AffineMap id = AffineMap::getMultiDimIdentityMap(dstType.getRank(), + rewriter.getContext()); + SmallVector iteratorTypes( + dstType.getRank(), utils::IteratorType::parallel); + rewriter.create( + copyOp->getLoc(), mapping.lookup(copyOp.getSource()), + mapping.lookup(copyOp.getTarget()), llvm::ArrayRef({id, id}), + iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args.front()); + }, + copyOp->getAttrs()); + rewriter.create(copyOp->getLoc()); + if (!attachAttr.empty()) { + funcOp->setAttr(attachAttr, rewriter.getUnitAttr()); } - return value; - }; - Value callSrc = getViewSource(src); - Value callDst = getViewSource(dst); - - auto symbolTableOp = SymbolTable::getNearestSymbolTable(copyOp); - SymbolTable symbolTable(symbolTableOp); - auto funcType = - rewriter.getFunctionType({callSrc.getType(), callDst.getType()}, {}); - - OpBuilder::InsertionGuard guard(rewriter); - // Insert before module terminator. 
- rewriter.setInsertionPoint(parentOp); - func::FuncOp funcOp = rewriter.create( - copyOp->getLoc(), "memref_copy_kernel", funcType); - symbolTable.insert(funcOp); - funcOp.setPrivate(); - - Block *entryBlock = funcOp.addEntryBlock(); - rewriter.setInsertionPointToStart(entryBlock); - IRMapping mapping; - mapping.map(ValueRange{callSrc, callDst}, entryBlock->getArguments()); - for (auto &&op : llvm::reverse(ops)) { - auto newOp = rewriter.clone(*op, mapping); - mapping.map(op, newOp); - } - AffineMap id = AffineMap::getMultiDimIdentityMap(dstType.getRank(), - rewriter.getContext()); - SmallVector iteratorTypes( - dstType.getRank(), utils::IteratorType::parallel); - rewriter.create( - copyOp->getLoc(), mapping.lookup(copyOp.getSource()), - mapping.lookup(copyOp.getTarget()), llvm::ArrayRef({id, id}), - iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args.front()); - }, - copyOp->getAttrs()); - rewriter.create(copyOp->getLoc()); - if (!attachAttr.empty()) { - funcOp->setAttr(attachAttr, rewriter.getUnitAttr()); - } - rewriter.setInsertionPoint(copyOp); - auto callOp = rewriter.replaceOpWithNewOp( - copyOp, funcOp, ValueRange{callSrc, callDst}); - callOp->setAttr(byre::getByreCallOpReadonlyOperandNumAttrName(), - rewriter.getIndexAttr(1)); + rewriter.setInsertionPoint(copyOp); + auto callOp = rewriter.replaceOpWithNewOp( + copyOp, funcOp, ValueRange{callSrc, callDst}); + callOp->setAttr(byre::getByreCallOpReadonlyOperandNumAttrName(), + rewriter.getIndexAttr(1)); + } else { + AffineMap id = AffineMap::getMultiDimIdentityMap(dstType.getRank(), + rewriter.getContext()); + SmallVector iteratorTypes( + dstType.getRank(), utils::IteratorType::parallel); + rewriter.replaceOpWithNewOp( + copyOp, src, dst, llvm::ArrayRef({id, id}), iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args.front()); + }, + copyOp->getAttrs()); + } return success(); } private: std::string anchorTag; std::string attachAttr; + bool 
outlining; }; struct MemrefCopyToLinalgPass : public MemrefCopyToLinalgPassBase { - MemrefCopyToLinalgPass(std::string anchorTag, std::string attachAttr) + MemrefCopyToLinalgPass(std::string anchorTag, std::string attachAttr, + bool outlining) : MemrefCopyToLinalgPassBase() { this->anchorTag = anchorTag; this->attachAttr = attachAttr; + this->outlining = outlining; } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); patterns.insert(context, this->anchorTag, - this->attachAttr); + this->attachAttr, this->outlining); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); @@ -144,8 +162,10 @@ struct MemrefCopyToLinalgPass } // namespace std::unique_ptr> -createMemrefCopyToLinalgPass(std::string anchorTag, std::string attachAttr) { - return std::make_unique(anchorTag, attachAttr); +createMemrefCopyToLinalgPass(std::string anchorTag, std::string attachAttr, + bool outlining) { + return std::make_unique(anchorTag, attachAttr, + outlining); } } // namespace mlir diff --git a/compiler/lib/Conversion/ToPTX/CollectGPUKernel.cpp b/compiler/lib/Conversion/ToPTX/CollectGPUKernel.cpp index d2556024f..7ccd11036 100644 --- a/compiler/lib/Conversion/ToPTX/CollectGPUKernel.cpp +++ b/compiler/lib/Conversion/ToPTX/CollectGPUKernel.cpp @@ -37,8 +37,10 @@ namespace { struct CollectGPUKernelPass : public CollectGPUKernelBase { - CollectGPUKernelPass(const std::string &name) : CollectGPUKernelBase() { + CollectGPUKernelPass(const std::string &name, bool removeHost) + : CollectGPUKernelBase() { this->moduleName = name; + this->removeHost = removeHost; } void runOnOperation() override { @@ -49,20 +51,20 @@ struct CollectGPUKernelPass bool found = false; GPUModuleOp dst; - for (auto &op : m.getBody()->without_terminator()) { - if (auto gm = dyn_cast(op)) { - if (gm.getName() == moduleName) { - found = true; - dst = gm; - } else { - gmCollector.push_back(gm); - } + for (auto 
gm : m.getOps()) { + if (gm.getName() == moduleName) { + found = true; + dst = gm; + } else { + gmCollector.push_back(gm); } } // Note FuncOps not in m.getBody()->without_terminator() - for (auto func : m.getOps()) { - removeOps.push_back(func); + if (removeHost) { + for (auto func : m.getOps()) { + removeOps.push_back(func); + } } if (gmCollector.size() == 0) { @@ -78,12 +80,13 @@ struct CollectGPUKernelPass } SymbolTable dstTable(dst); - for (auto gm : gmCollector) { for (auto &op : gm.getBody()->without_terminator()) { auto newOp = op.clone(); - dstTable.insert(newOp); + auto newName = dstTable.insert(newOp); + (void)SymbolTable::replaceAllSymbolUses(&op, newName, m); } + (void)SymbolTable::replaceAllSymbolUses(gm, dst.getNameAttr(), m); gm.erase(); } @@ -96,6 +99,6 @@ struct CollectGPUKernelPass } // namespace std::unique_ptr> -mlir::createCollectGPUKernelPass(const std::string &name) { - return std::make_unique(name); +mlir::createCollectGPUKernelPass(const std::string &name, bool removeHost) { + return std::make_unique(name, removeHost); } diff --git a/compiler/lib/Dialect/CMakeLists.txt b/compiler/lib/Dialect/CMakeLists.txt index fe905afe9..3e8627a7e 100644 --- a/compiler/lib/Dialect/CMakeLists.txt +++ b/compiler/lib/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Affine) add_subdirectory(Byre) add_subdirectory(Cat) add_subdirectory(Ccl) +add_subdirectory(GPU) add_subdirectory(Lace) add_subdirectory(Linalg) add_subdirectory(MemRef) diff --git a/compiler/lib/Dialect/Cat/IR/CatDialect.cpp b/compiler/lib/Dialect/Cat/IR/CatDialect.cpp index 24570d4cc..966a22e22 100644 --- a/compiler/lib/Dialect/Cat/IR/CatDialect.cpp +++ b/compiler/lib/Dialect/Cat/IR/CatDialect.cpp @@ -158,3 +158,9 @@ LogicalResult GemmRCRPermuteOp::verify() { this->getOutput(), this->getT1(), this->getT2(), "rcr"); } + +LogicalResult GemmRRRPermuteOp::verify() { + return VerifyGemmPermute0213Layout(this->getLhs(), this->getRhs(), + this->getOutput(), this->getT1(), + this->getT2(), 
"rrr"); +} diff --git a/compiler/lib/Dialect/GPU/CMakeLists.txt b/compiler/lib/Dialect/GPU/CMakeLists.txt new file mode 100644 index 000000000..5c919f7df --- /dev/null +++ b/compiler/lib/Dialect/GPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) \ No newline at end of file diff --git a/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..733282ba9 --- /dev/null +++ b/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(ByteIRGPUPasses + ShmAllocaToWorkgroupArg.cpp + + ADDITIONAL_HEADER_DIRS + ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Dialect/GPU + ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Dialect/GPU/Transforms + + DEPENDS + ByteIRGPUPassIncGen + ByteIRUtils + MLIRGPUDialect + + LINK_LIBS PUBLIC + ByteIRUtils + MLIRIR + MLIRGPUDialect + MLIRMemRefDialect + MLIRSupport +) diff --git a/compiler/lib/Dialect/GPU/Transforms/ShmAllocaToWorkgroupArg.cpp b/compiler/lib/Dialect/GPU/Transforms/ShmAllocaToWorkgroupArg.cpp new file mode 100644 index 000000000..808ed797f --- /dev/null +++ b/compiler/lib/Dialect/GPU/Transforms/ShmAllocaToWorkgroupArg.cpp @@ -0,0 +1,86 @@ +//===- ShmAllocaToWorkgroupArg.cpp --------------------------------- C++ +//-*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/GPU/Passes.h" +#include "byteir/Dialect/GPU/Transforms/Transforms.h" +#include "byteir/Transforms/MemoryPlanning.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include + +#define DEBUG_TYPE "shm-alloca-to-workgroup-arg" + +namespace mlir { +#define GEN_PASS_DEF_SHMALLOCATOWORKGROUPARG +#include "byteir/Dialect/GPU/Passes.h.inc" +} // namespace mlir + +using namespace llvm; +using namespace mlir; + +namespace { +struct ShmAllocaToWorkgroupArgPass + : public impl::ShmAllocaToWorkgroupArgBase { + void runOnOperation() override { + gpu::GPUModuleOp m = getOperation(); + WalkResult walkResult = m->walk([&](gpu::GPUFuncOp func) { + if (!func.isKernel()) + return WalkResult::advance(); + + // OpPassManager pm(func.getOperationName()); + // pm.addPass(createMemoryPlanningPass(/* alignment */ 1, /* alloca */ + // true, + // /* memory space */ 0, + // /* callback */ nullptr)); + // if (mlir::failed(runPipeline(pm, func))) { + // return WalkResult::interrupt(); + // } + + gpu::hoistShmAllocaToWorkgroup(func); + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) { + m->emitError() << "ShmAllocaToWorkgroupArgPass failed"; + signalPassFailure(); + } + } +}; +} // namespace + +void mlir::gpu::hoistShmAllocaToWorkgroup(gpu::GPUFuncOp func) { + func->walk([&](memref::AllocaOp alloca) { + auto memref = alloca.getType(); + if (auto memorySpace = llvm::dyn_cast_or_null( + memref.getMemorySpace())) { + if (memorySpace.getValue() == + gpu::GPUDialect::getWorkgroupAddressSpace()) { + Value workgroup = func.addWorkgroupAttribution(memref, alloca.getLoc()); + 
alloca.getMemref().replaceAllUsesWith(workgroup); + alloca->erase(); + } + } + }); +} diff --git a/compiler/lib/Dialect/Linalg/TransformOps/LinalgExtTransformOps.cpp b/compiler/lib/Dialect/Linalg/TransformOps/LinalgExtTransformOps.cpp index 64e456af3..bcb8df731 100644 --- a/compiler/lib/Dialect/Linalg/TransformOps/LinalgExtTransformOps.cpp +++ b/compiler/lib/Dialect/Linalg/TransformOps/LinalgExtTransformOps.cpp @@ -56,6 +56,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" @@ -63,6 +64,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" + #include using namespace mlir; @@ -150,6 +152,76 @@ transform::CollapseDimsOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// DetensorizeOp +//===----------------------------------------------------------------------===// +namespace { +LogicalResult detensorizeLinalgOp(OpBuilder &b, linalg::LinalgOp linalgOp) { + if (!linalgOp.hasTensorSemantics()) + return failure(); + + if (linalgOp.getNumLoops()) + return failure(); + + Location loc = linalgOp->getLoc(); + SmallVector scalars; + scalars.reserve(linalgOp->getNumOperands()); + for (auto &&operand : linalgOp->getOpOperands()) { + if (!linalgOp.payloadUsesValueFromOperand(&operand)) { + scalars.push_back(nullptr); + continue; + } + if (linalgOp.isScalar(&operand)) { + scalars.push_back(operand.get()); + continue; + } + auto tensorType = llvm::dyn_cast(operand.get().getType()); + if (!tensorType || !tensorType.hasRank() || tensorType.getRank() != 0) + return failure(); + + scalars.push_back( + b.create(loc, operand.get(), ValueRange())); + } + + Block *body = 
linalgOp.getBlock(); + IRMapping map; + map.map(body->getArguments(), scalars); + for (auto &&op : body->without_terminator()) { + b.clone(op, map); + } + + for (auto &&opOperand : linalgOp.getDpsInitOperands()) { + OpOperand *yieldOperand = linalgOp.getMatchingYieldValue(opOperand); + Value element = map.lookupOrDefault(yieldOperand->get()); + Value tensor = b.create( + loc, RankedTensorType::get({}, element.getType()), ValueRange(element)); + Value result = linalgOp.getTiedOpResult(opOperand); + result.replaceAllUsesWith(tensor); + } + linalgOp->erase(); + return success(); +} +} // namespace + +DiagnosedSilenceableFailure +transform::DetensorizeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto linalgOp = dyn_cast_or_null(target); + if (!linalgOp) + return emitDefaultDefiniteFailure(target) + << " detensorize transformation should be applied on linalg op"; + + OpBuilder builder(getContext()); + builder.setInsertionPoint(target); + if (failed(detensorizeLinalgOp(builder, linalgOp))) + return emitDefaultDefiniteFailure(linalgOp) + << " failed to detensorize op"; + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // replace unit extent dims //===----------------------------------------------------------------------===// @@ -1498,6 +1570,66 @@ LogicalResult transform::FuseOperandsOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// InsertSliceToCopyExtOp +//===----------------------------------------------------------------------===// +template +DiagnosedSilenceableFailure +insertSliceToCopyImpl(RewriterBase &rewriter, OpTy target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + static_assert(llvm::is_one_of() && + "wrong op type"); 
+ + if (auto copySource = + target.getSource().template getDefiningOp()) { + results.push_back(copySource); + return DiagnosedSilenceableFailure::success(); + } + + // If we are inside an InParallel region, temporarily set the insertion point + // outside: only tensor.parallel_insert_slice ops are allowed in there. + if constexpr (std::is_same_v) { + rewriter.setInsertionPoint( + target->template getParentOfType()); + } + + Value extracted = rewriter.create( + target.getLoc(), target.getSourceType(), target.getDest(), + target.getMixedOffsets(), target.getMixedSizes(), + target.getMixedStrides()); + Value copied = rewriter + .create(target.getLoc(), + target.getSource(), extracted) + .getResult(0); + // Reset the insertion point. + rewriter.setInsertionPoint(target); + rewriter.replaceOpWithNewOp( + target, copied, target.getDest(), target.getMixedOffsets(), + target.getMixedSizes(), target.getMixedStrides()); + + results.push_back(copied.getDefiningOp()); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure transform::InsertSliceToCopyExtOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *targetOp, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(targetOp); + if (auto target = dyn_cast(targetOp)) + return insertSliceToCopyImpl(rewriter, target, results, state); + if (auto target = dyn_cast(targetOp)) + return insertSliceToCopyImpl(rewriter, target, results, state); + + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "only InsertSliceOp and ParallelInsertSliceOp ops are supported"; + diag.attachNote(targetOp->getLoc()) << "target op"; + return diag; +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/compiler/lib/Dialect/Linalg/Transforms/FuseElementwise.cpp 
b/compiler/lib/Dialect/Linalg/Transforms/FuseElementwise.cpp index 1b874e9d2..fc0af1829 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/FuseElementwise.cpp +++ b/compiler/lib/Dialect/Linalg/Transforms/FuseElementwise.cpp @@ -544,9 +544,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, else return isProjectedPermutationAndAllowConst(map); }) && - genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > - 0 && - llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator); + genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0; } class ExpansionInfo { @@ -568,6 +566,9 @@ class ExpansionInfo { ArrayRef getExpandedShapeOfDim(unsigned i) const { return expandedShapeMap[i]; } + ArrayRef getIteratorTypes() const { + return iteratorTypes; + } ArrayRef getOriginalShape() const { return originalLoopExtent; } private: @@ -579,6 +580,8 @@ class ExpansionInfo { SmallVector> expandedShapeMap; /// Extent of the loop in the original operation. 
SmallVector originalLoopExtent; + /// Parallel types of the expanded loops + SmallVector iteratorTypes; unsigned expandedOpNumDims; }; @@ -591,6 +594,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, if (reassociationMaps.empty()) return failure(); AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand); + auto origIteratorTypes = linalgOp.getIteratorTypesArray(); SmallVector originalLoopRange = linalgOp.getStaticLoopRanges(); originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end()); @@ -621,8 +625,11 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, auto seq = llvm::seq(sum, sum + numFoldedDim.value()); reassociation.emplace_back(seq.begin(), seq.end()); sum += numFoldedDim.value(); + iteratorTypes.append(numFoldedDim.value(), + origIteratorTypes[numFoldedDim.index()]); } expandedOpNumDims = sum; + return success(); } @@ -871,15 +878,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, } } - // The iterator types of the expanded op are all parallel. 
- SmallVector iteratorTypes( - expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel); - TypeRange resultTypes = ValueRange(outputs).getTypes(); - auto fusedOp = - rewriter.create(genericOp.getLoc(), resultTypes, - /*inputs=*/expandedOpOperands, outputs, - expandedOpIndexingMaps, iteratorTypes); + auto fusedOp = rewriter.create(genericOp.getLoc(), resultTypes, + /*inputs=*/expandedOpOperands, + outputs, expandedOpIndexingMaps, + expansionInfo.getIteratorTypes()); Region &fusedRegion = fusedOp->getRegion(0); Region &originalRegion = genericOp->getRegion(0); rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); diff --git a/compiler/lib/Dialect/Linalg/Transforms/LinalgCollapseLoops.cpp b/compiler/lib/Dialect/Linalg/Transforms/LinalgCollapseLoops.cpp index defd678af..081599f13 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/LinalgCollapseLoops.cpp +++ b/compiler/lib/Dialect/Linalg/Transforms/LinalgCollapseLoops.cpp @@ -64,11 +64,12 @@ namespace { /// dimensions. It only applies these to "parallel" loops without mixing them /// with "reduction" types. 
static SmallVector -getCollapsibleLoops(linalg::GenericOp genericOp) { +getCollapsibleLoops(linalg::GenericOp genericOp, + utils::IteratorType iteratorType) { SmallVector contiguousLoops; SmallVector pDims; - genericOp.getParallelDims(pDims); + findPositionsOfType(genericOp.getIteratorTypesArray(), iteratorType, pDims); if (pDims.size() < 2) return contiguousLoops; @@ -76,15 +77,18 @@ getCollapsibleLoops(linalg::GenericOp genericOp) { auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) { for (AffineMap map : genericOp.getIndexingMapsArray()) { - bool foundSeq = false; - for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) { - if (resultExpr == nextExpr) { - foundSeq = (index > 0 && preExpr == map.getResult(index - 1)); - break; - } + auto prePos = map.getResultPosition(preExpr); + auto nextPos = map.getResultPosition(nextExpr); + if (!prePos.has_value()) { + if (nextPos.has_value()) + return false; + } else { + if (!nextPos.has_value()) + return false; + + if (prePos.value() + 1 != nextPos.value()) + return false; } - if (!foundSeq) - return false; } return true; }; @@ -519,13 +523,17 @@ FailureOr> collapseGenericOpIterationDimsEx( class CollapseLoopsOnGenericOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + CollapseLoopsOnGenericOp(MLIRContext *context, + utils::IteratorType iteratorType) + : OpRewritePattern(context), iteratorType(iteratorType) {} + LogicalResult matchAndRewrite(linalg::GenericOp op, PatternRewriter &rewriter) const override { // Collect collapsible loops // TODO: All rules come from iree project, add our own if (!isEligibleForCollapse(op)) return failure(); - auto loops = getCollapsibleLoops(op); + auto loops = getCollapsibleLoops(op, iteratorType); if (loops.empty()) return failure(); @@ -542,22 +550,31 @@ class CollapseLoopsOnGenericOp : public OpRewritePattern { rewriter.replaceOp(op, *replacements); return success(); } + +private: + utils::IteratorType iteratorType; }; 
struct LinalgCollapseLoopsPass : public impl::LinalgCollapseLoopsBase { + LinalgCollapseLoopsPass(utils::IteratorType iteratorType) + : LinalgCollapseLoopsBase() { + this->iteratorType = iteratorType; + } + void runOnOperation() override { auto op = getOperation(); auto context = op->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context, iteratorType); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } }; } // namespace -std::unique_ptr> mlir::createLinalgCollapseLoops() { - return std::make_unique(); +std::unique_ptr> +mlir::createLinalgCollapseLoops(utils::IteratorType iteratorType) { + return std::make_unique(iteratorType); } diff --git a/compiler/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Tensor/Transforms/CMakeLists.txt index 524713f46..47cd8bab0 100644 --- a/compiler/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(ByteIRTensorPasses CanonicalizeExt.cpp + TensorPadSpecialization.cpp ADDITIONAL_HEADER_DIRS ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Dialect/mhlo @@ -8,9 +9,11 @@ add_mlir_dialect_library(ByteIRTensorPasses DEPENDS ByteIRUtils + ByteIRTensorPassIncGen LINK_LIBS PUBLIC MLIRIR MLIRSupport + MLIRSCFDialect ByteIRUtils ) \ No newline at end of file diff --git a/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp index 2feff7437..c35039404 100644 --- a/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp +++ b/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp @@ -27,6 +27,7 @@ #include "byteir/Utils/AttrUtils.h" #include "byteir/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" @@ -121,6 +122,46 @@ struct 
RankReducedExtractSliceCollapseShape return success(); } }; + +/// Fold zero rank from_elements + insert_slice into insert +/// +/// Example: +/// +/// %0 = tensor.from_elements %scalar : tensor +/// %1 = tensor.insert_slice %0 into %1[%c256] : tensor into +/// tensor<1024xf32> +/// +/// will be folded into +/// +/// %0 = tensor.insert %scalar into %1[%c256] : tensor<1024xf32> +struct FoldZeroRankFromElementsInsertSlice + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + insertSliceOp.getSource().getDefiningOp(); + if (!fromElementsOp) + return failure(); + + RankedTensorType tensorType = insertSliceOp.getSourceType(); + if (tensorType.getRank() != 0) + return failure(); + + auto elements = fromElementsOp.getElements(); + if (elements.size() != 1) + return failure(); + + SmallVector indices = getValueOrCreateConstantIndexOp( + rewriter, insertSliceOp->getLoc(), + getMixedValues(insertSliceOp.getStaticOffsets(), + insertSliceOp.getOffsets(), rewriter)); + rewriter.replaceOpWithNewOp( + insertSliceOp, elements[0], insertSliceOp.getDest(), indices); + return success(); + } +}; } // namespace void mlir::tensor::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, @@ -132,6 +173,7 @@ void mlir::tensor::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, } patterns.add(ctx); + patterns.add(ctx); } void mlir::tensor::getCanonicalizationExtPatterns(RewritePatternSet &patterns, diff --git a/compiler/lib/Dialect/Tensor/Transforms/PassDetail.h b/compiler/lib/Dialect/Tensor/Transforms/PassDetail.h new file mode 100644 index 000000000..4214a74dc --- /dev/null +++ b/compiler/lib/Dialect/Tensor/Transforms/PassDetail.h @@ -0,0 +1,40 @@ +//===- PassDetail.h -------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H +#define BYTEIR_DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H + +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" + +// forward dialects for conversions +namespace mlir { + +namespace scf { +class SCFDialect; +} // namespace scf + +namespace tensor { +class TensorDialect; +} // namespace tensor + +#define GEN_PASS_CLASSES +#include "byteir/Dialect/Tensor/Passes.h.inc" + +} // namespace mlir + +#endif // BYTEIR_DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H diff --git a/compiler/lib/Dialect/Tensor/Transforms/TensorPadSpecialization.cpp b/compiler/lib/Dialect/Tensor/Transforms/TensorPadSpecialization.cpp new file mode 100644 index 000000000..43f750233 --- /dev/null +++ b/compiler/lib/Dialect/Tensor/Transforms/TensorPadSpecialization.cpp @@ -0,0 +1,242 @@ +//===- TensorPadSpecialization.cpp ---------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/Tensor/Transforms/TensorPadSpecialization.h" +#include "byteir/Utils/AttrUtils.h" +#include "byteir/Utils/Utils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Debug.h" + +#include "./PassDetail.h" + +#define DEBUG_TYPE "tensor-pad-specialization" + +using namespace mlir; + +namespace { +static LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + tensor::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + int64_t cnt = 0; + SmallVector tmp(indices.size()); + SmallVector dynamicIndices; + for (ArrayRef groups : collapseShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + 
dynamicIndices.push_back(indices[cnt++]); + int64_t groupSize = groups.size(); + + // Calculate suffix product for all collapse op source dimension sizes. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + + // Derive the index values along all dimensions of the source corresponding + // to the index wrt to collapsed shape op output. + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseShapeOp.getSrc().getType()).getRank(); + for (int64_t i = 0; i < srcRank; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + +struct FoldExtractOfCollapseShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const { + auto collapseShapeOp = + extractOp.getTensor().getDefiningOp(); + if (!collapseShapeOp) + return failure(); + + SmallVector indices(extractOp.getIndices().begin(), + extractOp.getIndices().end()); + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape(extractOp->getLoc(), rewriter, + collapseShapeOp, indices, + 
sourceIndices))) + return failure(); + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getType(), collapseShapeOp.getSrc(), + sourceIndices); + return success(); + } +}; + +struct FoldExtractOfExtractSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + auto extractSliceOp = + extractOp.getTensor().getDefiningOp(); + if (!extractSliceOp) + return failure(); + + SmallVector indices(extractOp.getIndices().begin(), + extractOp.getIndices().end()); + SmallVector sourceIndices; + affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, extractOp->getLoc(), extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), + indices, sourceIndices); + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getType(), extractSliceOp.getSource(), + sourceIndices); + return success(); + } +}; + +struct FoldExtractOfPad : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + auto padOp = extractOp.getTensor().getDefiningOp(); + if (!padOp) + return failure(); + + // Only constant padding value supported. + Value padValue = padOp.getConstantPaddingValue(); + if (!padValue) + return failure(); + + // Helper variables and functions for various arithmetic operations. These + // are used extensively for computing new offset/length and padding values. + Location loc = padOp->getLoc(); + AffineExpr dim0, dim1; + bindDims(rewriter.getContext(), dim0, dim1); + // Add two integers. + auto addMap = AffineMap::get(2, 0, {dim0 + dim1}); + auto add = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, addMap, + {v1, v2}); + }; + // Subtract two integers. 
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + auto sub = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, + {v1, v2}); + }; + + auto cmp = [&](OpFoldResult v1, OpFoldResult v2, + arith::CmpIPredicate pred) { + return rewriter.create( + loc, pred, getValueOrCreateConstantIndexOp(rewriter, loc, v1), + getValueOrCreateConstantIndexOp(rewriter, loc, v2)); + }; + + auto offsets = getAsOpFoldResult(extractOp.getIndices()); + SmallVector newOffsets; + Value inBound; + + int64_t rank = padOp.getSourceType().getRank(); + for (unsigned dim = 0; dim < rank; ++dim) { + auto low = padOp.getMixedLowPad()[dim]; + bool hasLowPad = !isConstantIntValue(low, 0); + auto offset = offsets[dim]; + auto srcSize = + tensor::getMixedSize(rewriter, loc, padOp.getSource(), dim); + + OpFoldResult newOffset = hasLowPad ? sub(offset, low) : offset; + newOffsets.push_back(newOffset); + auto lbcheck = cmp(low, offset, arith::CmpIPredicate::ule); + auto ubcheck = cmp(offset, hasLowPad ? 
add(low, srcSize) : srcSize, + arith::CmpIPredicate::ult); + auto check = rewriter.create(loc, lbcheck, ubcheck); + if (inBound) { + inBound = rewriter.create(loc, inBound, check); + } else { + inBound = check; + } + } + + rewriter.replaceOpWithNewOp( + extractOp, inBound, + [&](OpBuilder &b, Location loc) { + b.create( + loc, b.create( + loc, padOp.getSource(), + getValueOrCreateConstantIndexOp(b, loc, newOffsets)) + .getResult()); + }, + [&](OpBuilder &b, Location loc) { + b.create(loc, padValue); + }); + return success(); + } +}; + +struct TensorPadSpecializationPass + : public TensorPadSpecializationBase { + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr mlir::createTensorPadSpecializationPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp b/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp index cf5ff1e69..be6485a85 100644 --- a/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp +++ b/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/IR/Builders.h" @@ -89,6 +90,61 @@ void insertTransformIR(ModuleOp m, const TransformInsertionConfig &config) { } } +struct DetensorizeTransformInsertionPass + : public DetensorizeTransformInsertionBase< + DetensorizeTransformInsertionPass> { + explicit DetensorizeTransformInsertionPass(const std::string &funcAnchor, + const std::string 
&matchPrefix) + : DetensorizeTransformInsertionBase() { + this->funcAnchorAttr = funcAnchor; + this->matchPrefix = matchPrefix; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + linalg::registerTransformDialectExtension(registry); + } + + static bool isScalarTensorOp(linalg::LinalgOp linalgOp) { + if (!linalgOp.hasTensorSemantics()) + return false; + + if (linalgOp.getNumLoops() != 0) + return false; + + auto isScalarOrScalarTensorOperand = [&](OpOperand &operand) { + if (linalgOp.isScalar(&operand)) + return true; + + auto tensorType = + llvm::dyn_cast(operand.get().getType()); + if (!tensorType) + return false; + + return tensorType.getRank() == 0; + }; + return llvm::all_of(linalgOp->getOpOperands(), + isScalarOrScalarTensorOperand); + } + + void runOnOperation() override { + auto opFilter = [](Operation *op) { + if (auto linalgOp = llvm::dyn_cast_or_null(op)) { + return isScalarTensorOp(linalgOp); + } + return false; + }; + + auto transformBuilder = [](ImplicitLocOpBuilder &b, Operation *, + Value pdlValue) { + b.create(pdlValue); + }; + + insertTransformIR(getOperation(), {funcAnchorAttr, matchPrefix, opFilter, + transformBuilder}); + } +}; + struct FuseExtTransformInsertionPass : public FuseExtTransformInsertionBase { explicit FuseExtTransformInsertionPass( @@ -166,8 +222,46 @@ struct GenericTransformInsertionPass protected: TransformInsertionConfig config; }; + +struct RewriteInDPSTransformInsertionPass + : public RewriteInDPSTransformInsertionBase< + RewriteInDPSTransformInsertionPass> { + explicit RewriteInDPSTransformInsertionPass(const std::string &funcAnchor, + const std::string &matchPrefix) + : RewriteInDPSTransformInsertionBase() { + this->funcAnchorAttr = funcAnchor; + this->matchPrefix = matchPrefix; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + linalg::registerTransformDialectExtension(registry); + } + + void runOnOperation() override { + auto 
opFilter = [](Operation *op) { + return llvm::isa(op); + }; + + auto transformBuilder = [](ImplicitLocOpBuilder &b, Operation *, + Value pdlValue) { + b.create( + pdlValue.getType(), pdlValue); + }; + + insertTransformIR(getOperation(), {funcAnchorAttr, matchPrefix, opFilter, + transformBuilder}); + } +}; } // namespace +std::unique_ptr> +mlir::createDetensorizeTransformInsertionPass(const std::string &funcAnchor, + const std::string &matchPrefix) { + return std::make_unique(funcAnchor, + matchPrefix); +} + std::unique_ptr> mlir::createFuseExtTransformInsertionPass( const std::string &funcAnchor, const std::string &matchPrefix, @@ -182,4 +276,11 @@ std::unique_ptr> mlir::createGenericTransformInsertionPass( const TransformInsertionConfig &config) { return std::make_unique(config); +} + +std::unique_ptr> +mlir::createRewriteInDPSTransformInsertionPass(const std::string &funcAnchor, + const std::string &matchPrefix) { + return std::make_unique(funcAnchor, + matchPrefix); } \ No newline at end of file diff --git a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp index 277174dfa..f5ae6ffea 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp @@ -1698,6 +1698,60 @@ LogicalResult mlir::mhlo::foldReverseWithConstant(mhlo::ReverseOp op, return success(); } +// this pattern match a GatherOp with iota start_indices, +// the output of GatherOp maybe equal to the input. 
+LogicalResult mlir::mhlo::foldGatherWithInput(mhlo::GatherOp gatherOp, + PatternRewriter &rewriter) { + auto operand = gatherOp.getOperand(); + auto operandTy = operand.getType().cast(); + if (!operandTy.hasRank()) { + return failure(); + } + + auto resultTy = gatherOp.getType().cast(); + if (resultTy != operandTy) { + return failure(); + } + + auto startIndices = gatherOp.getStartIndices(); + auto startIndicesTy = startIndices.getType().cast(); + auto iotaOp = startIndices.getDefiningOp(); + if (!iotaOp || !startIndicesTy.hasRank()) { + return failure(); + } + + int64_t indexVectorDim = startIndicesTy.getRank(); + + auto dimensionNumbers = gatherOp.getDimensionNumbers(); + if (dimensionNumbers.getIndexVectorDim() != indexVectorDim || + indexVectorDim != 1) { + return failure(); + } + + if (dimensionNumbers.getStartIndexMap().size() != 1) { + return failure(); + } + + int64_t startIndexMap = dimensionNumbers.getStartIndexMap()[0]; + auto collapsedSilceDims = dimensionNumbers.getCollapsedSliceDims(); + bool mapTocollapsedDim = false; + + for (auto dims : collapsedSilceDims) { + if (dims == startIndexMap) { + mapTocollapsedDim = true; + break; + } + } + // if the start index and offset index are disjoint, + // and the start index is generate by IotaOp, + // the output of gatherOp is equal to input. 
+ if (mapTocollapsedDim) { + rewriter.replaceOp(gatherOp, operand); + return success(); + } + return failure(); +} + void mlir::mhlo::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, MLIRContext *ctx, bool blindFold) { @@ -1725,6 +1779,7 @@ void mlir::mhlo::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, patterns.add(mlir::mhlo::simplifyCumsumToIota); patterns.add(mlir::mhlo::simplifyTransposeReshapeTranspose); patterns.add(mlir::mhlo::foldReverseWithConstant); + patterns.add(mlir::mhlo::foldGatherWithInput); if (blindFold) { patterns.add(mlir::mhlo::foldLargeConcatenate); } diff --git a/compiler/lib/Dialect/mhlo/Transforms/CatFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/CatFusion.cpp index 6be432b7d..6eac9b56c 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/CatFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/CatFusion.cpp @@ -67,6 +67,8 @@ bool isFusibleWith(Operation *target, Operation * /*start*/) { return true; } bool isValidSingleOp(Operation *op) { return true; } +bool isValidFusionPattern(const MhloFusionPattern &) { return true; } + bool isFusibleCandidateAggressive(Operation *op) { if (isa(op)) return true; @@ -99,14 +101,16 @@ bool isValidSingleOpAggressive(Operation *op) { } static GenericFuserConfig config{ - getByteIRCatFusionAttrName(), cat_fusion::isFusibleCandidate, - cat_fusion::isFusibleStart, cat_fusion::isFusibleTrigger, - cat_fusion::isFusibleWith, cat_fusion::isValidSingleOp}; + getByteIRCatFusionAttrName(), cat_fusion::isFusibleCandidate, + cat_fusion::isFusibleStart, cat_fusion::isFusibleTrigger, + cat_fusion::isFusibleWith, cat_fusion::isValidSingleOp, + cat_fusion::isValidFusionPattern}; static GenericFuserConfig aggressiveConfig{ - getByteIRCatFusionAttrName(), cat_fusion::isFusibleCandidateAggressive, - cat_fusion::isFusibleStart, cat_fusion::isFusibleTrigger, - cat_fusion::isFusibleWith, cat_fusion::isValidSingleOpAggressive}; + getByteIRCatFusionAttrName(), cat_fusion::isFusibleCandidateAggressive, 
+ cat_fusion::isFusibleStart, cat_fusion::isFusibleTrigger, + cat_fusion::isFusibleWith, cat_fusion::isValidSingleOpAggressive, + cat_fusion::isValidFusionPattern}; } // namespace cat_fusion diff --git a/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp b/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp index 0f7bd48f5..bca092517 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/ConvertOpToCustomCall.cpp @@ -1,4 +1,4 @@ -//===- ConvertRngToCustomCall.cpp -----------------------------*--- C++ -*-===// +//===- ConvertOpToCustomCall.cpp ------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -74,6 +74,20 @@ func::CallOp getOrCreateCallGetSeedOp(func::FuncOp func, return callGetSeedOp; } +llvm::SmallVector getDefaultAttrs(PatternRewriter &rewriter) { + llvm::SmallVector attrs; + attrs.emplace_back(rewriter.getStringAttr("has_side_effect"), + rewriter.getBoolAttr(false)); + attrs.emplace_back(rewriter.getStringAttr("backend_config"), + rewriter.getStringAttr("")); + attrs.emplace_back(rewriter.getStringAttr("api_version"), + rewriter.getI32IntegerAttr(static_cast( + mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL))); + attrs.emplace_back(rewriter.getStringAttr("called_computations"), + rewriter.getArrayAttr({})); + return attrs; +} + struct ConvertRngUniformToCustomCall : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -120,6 +134,76 @@ struct ConvertRngUniformToCustomCall : public OpRewritePattern { return success(); } }; + +struct ConvertFlashFwdToCustomCall + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::CustomCallOp op, + PatternRewriter &rewriter) const override { + auto opName = op.getCallTargetName(); + if (opName != getFlashAttnFwdName()) + 
return rewriter.notifyMatchFailure(op, "op name not match"); + + auto resultNum = op.getNumResults(); + if (resultNum != 4) + return rewriter.notifyMatchFailure(op, "op result num not match"); + auto q = op.getOperand(0); + auto k = op.getOperand(1); + auto v = op.getOperand(2); + Type outType = op.getResult(0).getType(); + Type softmaxLseType = op.getResult(1).getType(); + Type softmaxType = op.getResult(2).getType(); + + TensorType seedOrOffsetType = + RankedTensorType::get({}, rewriter.getI64Type()); + + ModuleOp module = op->getParentRegion()->getParentOfType(); + auto functionType = FunctionType::get(module.getContext(), {}, + ArrayRef{seedOrOffsetType}); + func::FuncOp getSeedFunc = getOrCreatePrivateFunctionDeclare( + module, "GetSeedFunc", "GetSeed", functionType); + func::FuncOp nextOffsetFunc = getOrCreatePrivateFunctionDeclare( + module, "NextOffsetFunc", "NextOffset", functionType); + + // avoid to call @getSeed every time + auto getSeedOp = getOrCreateCallGetSeedOp( + op->getParentRegion()->getParentOfType(), getSeedFunc, + rewriter); + auto getOffsetOp = rewriter.create( + op->getLoc(), nextOffsetFunc, ArrayRef{}); + + TensorType seedOrOffsetReshapedType = + RankedTensorType::get({1}, rewriter.getI64Type()); + TensorType rngStateType = RankedTensorType::get({2}, rewriter.getI64Type()); + auto reshapeSeedOp = rewriter.create( + op.getLoc(), seedOrOffsetReshapedType, getSeedOp.getResult(0)); + auto reshapeOffsetOp = rewriter.create( + op.getLoc(), seedOrOffsetReshapedType, getOffsetOp.getResult(0)); + + auto concatOp = rewriter.create( + op.getLoc(), rngStateType, + ValueRange{reshapeSeedOp.getResult(), reshapeOffsetOp.getResult()}, 0); + SmallVector bufferArgs{q, k, v, concatOp.getResult()}; + auto dictAttr = + op->template getAttrOfType(getCustomCallAttrName()); + auto attrs = getDefaultAttrs(rewriter); + attrs.emplace_back(rewriter.getStringAttr("call_target_name"), + rewriter.getStringAttr(getFlashAttnFwdName())); + 
attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()), + dictAttr); + auto customCallOp = rewriter.create( + op->getLoc(), ArrayRef{outType, softmaxLseType, softmaxType}, + bufferArgs, ArrayRef{attrs}); + Value outPad = customCallOp.getResult(0); + Value softmaxLse = customCallOp.getResult(1); + Value softmaxReturn = customCallOp.getResult(2); + ValueRange results{outPad, softmaxLse, softmaxReturn, concatOp.getResult()}; + rewriter.replaceOp(op, results); + return success(); + } +}; + struct ConvertOpToCustomCallPass : public ConvertOpToCustomCallBase { @@ -140,6 +224,7 @@ struct ConvertOpToCustomCallPass RewritePatternSet patterns(context); populateRngPatternToCustomCall(patterns); + populateFlashFwdRewritePattern(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns))) { @@ -155,6 +240,10 @@ void mlir::populateRngPatternToCustomCall(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +void mlir::populateFlashFwdRewritePattern(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + std::unique_ptr> mlir::createConvertOpToCustomCallPass(llvm::StringRef anchor) { return std::make_unique(anchor); diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index 0caf95fbf..894c5866c 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -96,10 +96,13 @@ bool isValidSingleOp(Operation *op) { isCustomMhloRngOp(op); } +bool isValidFusionPattern(const MhloFusionPattern &) { return true; } + static GenericFuserConfig config{ getByteIRElementwiseFusionAttrName(), elementwise::isFusibleCandidate, elementwise::isFusibleStart, elementwise::isFusibleTrigger, - elementwise::isFusibleWith, elementwise::isValidSingleOp}; + elementwise::isFusibleWith, elementwise::isValidSingleOp, + 
elementwise::isValidFusionPattern}; } // namespace elementwise @@ -126,15 +129,85 @@ bool isFusibleWith(Operation * /*target*/, Operation * /*start*/) { bool isValidSingleOp(Operation *op) { return false; } +bool isValidFusionPattern(const MhloFusionPattern &) { return true; } + static GenericFuserConfig config{getByteIRMatmulEpilogueFusionAttrName(), matmul_epilogue::isFusibleCandidate, matmul_epilogue::isFusibleStart, matmul_epilogue::isFusibleTrigger, matmul_epilogue::isFusibleWith, - matmul_epilogue::isValidSingleOp}; + matmul_epilogue::isValidSingleOp, + matmul_epilogue::isValidFusionPattern}; } // namespace matmul_epilogue +namespace reduction { +// TODO: maybe we should support non-splat constant on device in future +bool isFusibleCandidate(Operation *op) { + return isMhlo(op) && (op->hasTrait<::mlir::OpTrait::Elementwise>() || + op->hasTrait() || + isSplatMhloConstantLike(op) || + isa(op)); +} + +// every candidate can start +bool isFusibleStart(Operation *op) { return true; } + +bool isFusibleTrigger(Operation *op) { + if (op->hasTrait<::mlir::OpTrait::Elementwise>() || + op->hasTrait() || + isa(op)) { + return true; + } + + // if broadcast, check whether its operand is only used in broadcast + if (isa(op)) { + auto src = op->getOperand(0); + // is foldable we just allow + if (isDeepMhloFoldable(src.getDefiningOp())) { + return true; + } + // otherwise, check it is only used in broadcast + // return useCount(src) == 1; + // LWC FIXME: change back to above after broadcast fusion resolve. 
+ return false; + } + + if (isa(op)) + return true; + + return false; +} + +bool isFusibleWith(Operation *target, Operation * /*start*/) { + return (target->hasTrait<::mlir::OpTrait::Elementwise>() || + target->hasTrait() || + isSplatMhloConstantLike(target) || + isa( + target)) && + target->hasOneUse(); +} + +bool isValidSingleOp(Operation *op) { return isa(op); } + +bool isValidFusionPattern(const MhloFusionPattern &pattern) { + SmallVector outputs = getOutputsOfCluster(pattern); + if (outputs.size() == 1) { + if (outputs[0].getDefiningOp()) + return true; + } + return false; +} + +static GenericFuserConfig config{ + getByteIRReductionFusionAttrName(), reduction::isFusibleCandidate, + reduction::isFusibleStart, reduction::isFusibleTrigger, + reduction::isFusibleWith, reduction::isValidSingleOp, + reduction::isValidFusionPattern}; + +} // namespace reduction + // a derived fusion pass for elementwise struct ElementwiseFusionPass : public GenericFusionPass { @@ -188,6 +261,29 @@ struct MatmulEpilogueFusionPass ::llvm::StringRef getName() const override { return "MatmulEpilogueFusion"; } }; +// a derived fusion pass for reduction fusion +struct ReductionFusionPass : public GenericFusionPass { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReductionFusionPass) + + ReductionFusionPass() : GenericFusionPass(reduction::config, false) {} + + /// Returns the command-line argument attached to this pass. + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("fuse-reduction"); + } + ::llvm::StringRef getArgument() const override { return "fuse-reduction"; } + + ::llvm::StringRef getDescription() const override { + return "Fuse reduction with its producer"; + } + + /// Returns the derived pass name. 
+ static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("ReductionFusion"); + } + ::llvm::StringRef getName() const override { return "ReductionFusion"; } +}; } // namespace std::unique_ptr> @@ -199,3 +295,7 @@ std::unique_ptr> mlir::createMatmulEpilogueFusionPass() { return std::make_unique(); } + +std::unique_ptr> mlir::createReductionFusionPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp index 678d6bc9c..6db573e5f 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp @@ -47,12 +47,15 @@ bool isFusibleWith(Operation *, Operation *) { return true; } bool isValidSingleOp(Operation *) { return true; } +bool isValidFusionPattern(const MhloFusionPattern &) { return true; } + static GenericFuserConfig config{getByteIRHloAggressiveFusionAttrName(), aggressive_fusion::isFusibleCandidate, aggressive_fusion::isFusibleStart, aggressive_fusion::isFusibleTrigger, aggressive_fusion::isFusibleWith, - aggressive_fusion::isValidSingleOp}; + aggressive_fusion::isValidSingleOp, + aggressive_fusion::isValidFusionPattern}; } // namespace aggressive_fusion diff --git a/compiler/lib/Dialect/mhlo/Transforms/HloMoveDown.cpp b/compiler/lib/Dialect/mhlo/Transforms/HloMoveDown.cpp index 55b74ef26..7bcbd1371 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/HloMoveDown.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/HloMoveDown.cpp @@ -59,12 +59,21 @@ struct TransposeMoveDownPattern : public HloMoveDownPattern { LogicalResult matchAndRewrite(mhlo::TransposeOp op, PatternRewriter &rewriter) const override { auto value = op.getResult(); - auto operandType = op.getOperand().getType(); // T1 as Transpose: T1 -> T2 - // early termination if not allMultiUser nor multiUser but has multi users if (!allMultiUser && !multiUser && userCount(value) != 1) { 
return failure(); } + auto permutationAttr = op.getPermutation(); + + auto isTransposeWithSamePermutation = + [&permutationAttr](Value val) -> bool { + auto op = val.getDefiningOp(); + if (!op) { + return false; + } else { + return op.getPermutation() == permutationAttr; + } + }; llvm::SetVector users; for (auto user : value.getUsers()) { @@ -94,13 +103,19 @@ struct TransposeMoveDownPattern : public HloMoveDownPattern { // isElementwiseOneResult(user) == true bool failed = false; for (auto operand : user->getOperands()) { - if (operand != value && !isSplatMhloConstantValue(operand)) { - if (allMultiUser) - return failure(); - failed = true; - break; + if (operand == value) { + continue; + } else if (isDenseMhloConstantValue(operand)) { + continue; + } else if (isTransposeWithSamePermutation(operand)) { + continue; } + if (allMultiUser) + return failure(); + failed = true; + break; } + if (failed) continue; users.insert(user); @@ -119,8 +134,10 @@ struct TransposeMoveDownPattern : public HloMoveDownPattern { if (!bvm.contains(value)) { bvm.map(value, op.getOperand()); } + } else if (isTransposeWithSamePermutation(operand)) { + bvm.map(operand, operand.getDefiningOp().getOperand()); } else { - // isSplatMhloConstantValue(operand) == true + // isDenseMhloConstantValue(operand) == true // since it has been checked when collecting users if (!constInputs.contains(operand)) { constInputs.insert(operand); @@ -130,14 +147,19 @@ struct TransposeMoveDownPattern : public HloMoveDownPattern { // create all const and put into bvm for (auto input : constInputs) { - ElementsAttr oldConstAttr = - input.getDefiningOp().getValue(); - auto newConstAttr = reshapeSplatElementsAttr(oldConstAttr, operandType); - auto newConstOp = - rewriter.create(op->getLoc(), *newConstAttr); - bvm.map(input, newConstOp.getOutput()); + SmallVector newPermutation(permutationAttr.size()); + std::for_each(permutationAttr.value_begin(), + permutationAttr.value_end(), + [i = 0, &newPermutation](auto e) 
mutable { + newPermutation[e.getSExtValue()] = (uint64_t)i++; + }); + auto newPermutationAttr = DenseIntElementsAttr::get( + permutationAttr.getType(), newPermutation); + auto ConstOp = input.getDefiningOp(); + auto newTransposeOp = rewriter.create( + ConstOp.getLoc(), ConstOp.getOutput(), newPermutationAttr); + bvm.map(input, newTransposeOp.getResult()); } - auto maybeResultTypes = mixTypes(/*cloneFromElementTypes*/ user->getResultTypes(), /*cloneFromShapes*/ op->getOperandTypes()); @@ -145,6 +167,8 @@ struct TransposeMoveDownPattern : public HloMoveDownPattern { // maybeResultTypes should always have value assert(maybeResultTypes.has_value()); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(user); // clone an elementwise op as producer auto newProducer = cloneAndReplaceResultTypes(rewriter, user, bvm, *maybeResultTypes); diff --git a/compiler/lib/Pipelines/BufferizeOpt.cpp b/compiler/lib/Pipelines/BufferizeOpt.cpp index dcbb11485..4a5f2c5bf 100644 --- a/compiler/lib/Pipelines/BufferizeOpt.cpp +++ b/compiler/lib/Pipelines/BufferizeOpt.cpp @@ -23,6 +23,7 @@ #include "byteir/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include "transforms/passes.h" @@ -37,6 +38,7 @@ void mlir::createByteIRBufferizeOptPipeline( pm.addPass(byteir::createOneShotBufferizePass()); addCleanUpExtPassPipeline(pm); + pm.addNestedPass(memref::createFoldMemRefAliasOpsPass()); // clean-up possible redundant copy from bufferization // perform twice, since cse is not greedy-based pm.addNestedPass(createRemoveCopyPass()); diff --git a/compiler/lib/Pipelines/ByreOpt.cpp b/compiler/lib/Pipelines/ByreOpt.cpp index 0e78f8742..ad4bba814 100644 --- a/compiler/lib/Pipelines/ByreOpt.cpp +++ b/compiler/lib/Pipelines/ByreOpt.cpp @@ -48,7 +48,10 @@ void createByreOptPipelineImpl(OpPassManager &pm, const std::string 
&entryFunc, OpPassManager anchoredPM(func::FuncOp::getOperationName()); if (!disableMemoryPlanning) { // underlying memory of constant op cannot be reused - anchoredPM.addPass(createMemoryPlanningPass(128, nullptr)); + anchoredPM.addPass(createMemoryPlanningPass(/* alignment */ 128, + /* alloca */ false, + /* memory space */ 0, + /* callback */ nullptr)); anchoredPM.addPass(createCanonicalizerPass()); } anchoredPM.addPass(createConvertMemrefToByrePass()); diff --git a/compiler/lib/Pipelines/GPU/CMakeLists.txt b/compiler/lib/Pipelines/GPU/CMakeLists.txt index 4ab1ad0a7..8eea9ad17 100644 --- a/compiler/lib/Pipelines/GPU/CMakeLists.txt +++ b/compiler/lib/Pipelines/GPU/CMakeLists.txt @@ -2,7 +2,9 @@ add_mlir_library(ByteIRGPUPipelines ElementwiseCodegen.cpp GPUOpt.cpp LinalgMemrefGPU.cpp + MappingForall.cpp NVVMCodegen.cpp + ReductionCodegen.cpp ADDITIONAL_HEADER_DIRS ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Pipelines/GPU @@ -15,6 +17,7 @@ add_mlir_library(ByteIRGPUPipelines MLIRBufferTransforms LINK_LIBS PUBLIC + ByteIRGPUPasses ByteIRLinalgPasses ByteIRPipelineCommon ByteIRUtils @@ -22,4 +25,5 @@ add_mlir_library(ByteIRGPUPipelines ByteIRToPTX MLIRIR MLIRTransforms + MLIRLinalgExtTransformOps ) \ No newline at end of file diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp index e5cf725b8..842901b04 100644 --- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp +++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp @@ -20,10 +20,14 @@ #include "byteir/Conversion/ToGPU/ToGPU.h" #include "byteir/Conversion/ToPTX/ToPTX.h" #include "byteir/Dialect/Affine/Passes.h" +#include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/SCF/Passes.h" +#include "byteir/Dialect/Transform/Transforms/TransformDialectInterpreter.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/Common/Utils.h" +#include "byteir/Pipelines/GPU/MappingForall.h" #include "byteir/Transforms/Passes.h" +#include "byteir/Transforms/RemoveFuncBody.h" #include 
"mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -35,8 +39,9 @@ using namespace mlir; using namespace mlir::bufferization; namespace { -void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, - const std::string &target) { +void createElementwiseGPUOptPipelineImpl(OpPassManager &pm, + const bool &useBarePtrCallConv, + const std::string &target) { // apply PromotoBufferStack to func's with // getByteIRElementwiseFusionAttrName { @@ -73,6 +78,36 @@ void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, pm.addNestedPass(createGenPTXConfigPass(useBarePtrCallConv)); } +void createReductionGPUOptPipelineImpl(OpPassManager &pm) { + GPUMappingForallOptions options; + options.funcAnchor = getByteIRReductionFusionAttrName().str(); + createGPUMappingForallTransform(pm, options); + pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createGpuLauchSinkIndexComputationsPass()); + + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + + anchoredPM.addPass(createPromoteBuffersToStackPass( + /*isSmallAlloc =*/[](Value value) { + return value.getParentRegion()->getParentOfType(); + })); + + pm.addNestedPass(createAnchoredPipelinePass( + getByteIRReductionFusionAttrName(), anchoredPM)); + } + pm.addPass(createGpuKernelOutliningPass()); +} + +void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, + const std::string &target) { + createElementwiseGPUOptPipelineImpl(pm, useBarePtrCallConv, target); + createReductionGPUOptPipelineImpl(pm); + pm.addPass(createCollectGPUKernelPass("unified", false)); +} + } // namespace void mlir::createGPUOptPipeline(OpPassManager &pm, diff --git a/compiler/lib/Pipelines/GPU/MappingForall.cpp b/compiler/lib/Pipelines/GPU/MappingForall.cpp new file mode 100644 index 
000000000..633214bee --- /dev/null +++ b/compiler/lib/Pipelines/GPU/MappingForall.cpp @@ -0,0 +1,148 @@ +//===- MappingForall.cpp --------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Pipelines/GPU/MappingForall.h" + +#include "byteir/Conversion/ToGPU/ToGPU.h" +#include "byteir/Conversion/ToLLVM/ToLLVM.h" +#include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" +#include "byteir/Dialect/Transform/IR/TransformExtOps.h" +#include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" +#include "byteir/Pipelines/Common/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallSet.h" + +#include + +using namespace mlir; + +namespace { + +static constexpr int64_t kMaximumBlockDim = 1024; + +struct MappingForallConfig { + SmallVector blockDims; +}; + +// TODO: move to common helper +bool isMappedToGPUBlocks(scf::ForallOp forallOp) { + if (auto mapping = 
forallOp.getMappingAttr()) { + if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isMappedToGPUThreads(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +void updateBlockDims(scf::ForallOp forallOp, SmallVector &blockDims) { + for (auto &&[lb, ub, step, mappingAttr] : llvm::zip( + forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), forallOp.getMappingAttr().getValue())) { + if (auto threadMapping = + llvm::dyn_cast_or_null(mappingAttr)) { + auto numIterations = constantTripCount(lb, ub, step); + auto threadIdx = threadMapping.getMappingId(); + if (numIterations.has_value()) { + blockDims[threadIdx] = + std::max(blockDims[threadIdx], numIterations.value()); + } + } + } +} + +std::optional +getMappingForallConfig(scf::ForallOp forallOp) { + if (!isMappedToGPUBlocks(forallOp)) + return std::nullopt; + + SmallVector blockDims{1, 1, 1}; + auto &&block = forallOp.getRegion().front(); + for (auto &&nestedForall : block.getOps()) { + if (isMappedToGPUThreads(nestedForall)) { + updateBlockDims(nestedForall, blockDims); + } + } + + if (blockDims[0] * blockDims[1] * blockDims[2] > kMaximumBlockDim) { + return std::nullopt; + } + return MappingForallConfig{blockDims}; +} + +void createGPUMappingForallTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto forallOp = llvm::dyn_cast_or_null(op)) { + return getMappingForallConfig(forallOp).has_value(); + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto mappingConfig = + 
getMappingForallConfig(llvm::cast(op)).value(); + auto pdlType = pdl::OperationType::get(b.getContext()); + auto launchOp = b.create( + /* result type */ pdlType, + /* target */ pdlV, + /* grid_dims */ llvm::ArrayRef{}, + /* generate_gpu_launch */ true); + + b.create( + /* result type*/ pdlType, + /* target */ launchOp.getResult(), + /* block_dims */ mappingConfig.blockDims, + /* warp_dims */ llvm::ArrayRef{}, + /* sync_after_distribute*/ true); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} +} // namespace + +void mlir::createGPUMappingForallTransform( + OpPassManager &pm, const GPUMappingForallOptions &options) { + invokeOpPassPipelineBuilder(createGPUMappingForallTransformImpl, pm, + options.funcAnchor, options.annotatePrefix); +} diff --git a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp index aecfc5b84..d23567a58 100644 --- a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp @@ -19,6 +19,7 @@ #include "byteir/Conversion/GPUToNVVM/GPUToNVVM.h" #include "byteir/Conversion/ToPTX/ToPTX.h" +#include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/MemRef/Transforms/ExtractAddressComputation.h" #include "byteir/Dialect/MemRef/Transforms/SimplifyLinearizedIndex.h" #include "byteir/Dialect/mhlo/Passes.h" @@ -39,6 +40,9 @@ void createNVVMCodegenPipelineImpl(OpPassManager &pm, // TODO add target for supporting different SMs // TODO use target to decide passes pm.addPass(createCollectGPUKernelPass()); + pm.addNestedPass(createShmAllocaToWorkgroupArg()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); pm.addPass(createConvertSCFToCFPass()); pm.addPass(createExtractAddressComputationPass()); pm.addPass(memref::createExpandStridedMetadataPass()); diff --git a/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp new file mode 100644 index 000000000..664fce40e --- /dev/null +++ 
b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp @@ -0,0 +1,942 @@ +//===- ReductionCodegen.cpp ---------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Pipelines/GPU/ReductionCodegen.h" + +#include "byteir/Conversion/ToGPU/ToGPU.h" +#include "byteir/Conversion/ToLLVM/ToLLVM.h" +#include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" +#include "byteir/Dialect/Transform/IR/TransformExtOps.h" +#include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" +#include "byteir/Pipelines/Common/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallSet.h" + +#include + +using namespace mlir; + +namespace { +//----------------------------------------------------------------------------// +// common helpers 
+//----------------------------------------------------------------------------// +// TODO: move to common header + +constexpr bool isPowerOf2(int64_t n) { return (!(n & (n - 1))); } + +constexpr int64_t nextPowerOf2(int64_t n) { + return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2))); +} + +bool isMappedToGPUBlocks(scf::ForOp forOp) { + if (auto loopToSIMTAttr = + forOp->getAttrOfType(getLoopToSIMTAttrName())) { + auto mappingTo = loopToSIMTAttr.getValue(); + if (mappingTo == getBlockIdXName() || mappingTo == getBlockIdYName() || + mappingTo == getBlockIdZName()) { + return true; + } + } + return false; +} + +bool isMappedToGPUBlocks(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isMappedToGPUBlocks(Operation *op) { + if (auto forOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUBlocks(forOp); + } + if (auto forallOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUBlocks(forallOp); + } + return false; +} + +bool isMappedToGPUThreads(scf::ForOp forOp) { + if (auto loopToSIMTAttr = + forOp->getAttrOfType(getLoopToSIMTAttrName())) { + auto mappingTo = loopToSIMTAttr.getValue(); + if (mappingTo == getThreadIdXName() || mappingTo == getThreadIdYName() || + mappingTo == getThreadIdZName()) { + return true; + } + } + return false; +} + +bool isMappedToGPUThreads(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isMappedToGPUThreads(Operation *op) { + if (auto forOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUThreads(forOp); + } + if (auto forallOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUThreads(forallOp); + } + return false; +} + +uint64_t getNumTiledLoops(ArrayRef tileSizes) { + 
return llvm::count_if(tileSizes, + [](int64_t tileSize) { return tileSize > 0; }); +} + +std::optional getReductionDim(linalg::GenericOp genericOp) { + SmallVector reductionDims; + genericOp.getReductionDims(reductionDims); + if (reductionDims.size() == 1) { + return reductionDims[0]; + } + return std::nullopt; +} + +std::optional getOperandReductionDim(OpOperand &operand) { + auto genericOp = llvm::dyn_cast(operand.getOwner()); + if (!genericOp) + return std::nullopt; + + auto dim = getReductionDim(genericOp); + if (!dim.has_value()) + return std::nullopt; + + auto affineMap = genericOp.getIndexingMapsArray()[operand.getOperandNumber()]; + if (!affineMap || !affineMap.isProjectedPermutation()) + return std::nullopt; + + for (auto &&en : llvm::enumerate(affineMap.getResults())) { + if (auto dimExpr = en.value().dyn_cast()) { + if (dimExpr.getPosition() == *dim) { + return en.index(); + } + } + } + + return std::nullopt; +} + +SmallVector getDynamicDims(linalg::GenericOp genericOp) { + auto staticLoopRanges = genericOp.getStaticLoopRanges(); + SmallVector ret; + for (int64_t i = 0; i < staticLoopRanges.size(); ++i) { + if (ShapedType::isDynamic(staticLoopRanges[i])) { + ret.push_back(i); + } + } + return ret; +} + +//----------------------------------------------------------------------------// +// configuration structs +//----------------------------------------------------------------------------// + +static constexpr StringLiteral kGridReduction = "__grid_reduction__"; +static constexpr StringLiteral kBlockReduction = "__block_reduction__"; +static constexpr StringLiteral kWarpReduction = "__warp_reduction__"; +static constexpr StringLiteral kThreadReduction = "__thread_reduction__"; + +struct ProducerSelector { + uint64_t operandNumber; + llvm::StringRef opName; + std::vector producerSelectors; + + ProducerSelector(uint64_t operandNumber, llvm::StringRef opName) + : operandNumber(operandNumber), opName(opName) {} + + static bool detectFillOperand(OpOperand 
*opOperand, + std::vector &selectors) { + if (opOperand->get().getDefiningOp()) { + selectors.emplace_back(opOperand->getOperandNumber(), + linalg::FillOp::getOperationName()); + return true; + } + return false; + } + + static bool detectPadOperand(OpOperand *opOperand, + std::vector &selectors) { + Operation *definingOp = opOperand->get().getDefiningOp(); + if (!definingOp) + return false; + + if (llvm::isa(definingOp)) { + ProducerSelector selector(opOperand->getOperandNumber(), + definingOp->getName().getStringRef()); + if (detectPadOperand(&definingOp->getOpOperand(0), + selector.producerSelectors)) { + selectors.emplace_back(std::move(selector)); + return true; + } + } else if (llvm::isa(definingOp)) { + selectors.emplace_back(opOperand->getOperandNumber(), + tensor::PadOp::getOperationName()); + return true; + } + return false; + } +}; + +struct GridSplitConfig { + int64_t splitFactor; + int64_t dimension; + + void apply(ImplicitLocOpBuilder &b, Value pdlV); +}; + +struct GridTileConfig { + SmallVector tileSizes; + SmallVector mapping; + std::vector fuseCandidates; + + void apply(ImplicitLocOpBuilder &b, Value pdlV, bool usingForall); +}; + +struct BlockSplitConfig { + SmallVector splitFactors; + SmallVector dimensions; + SmallVector padDims; + SmallVector padValues; + + void apply(ImplicitLocOpBuilder &b, Value pdlV); +}; + +struct BlockTileConfig { + SmallVector tileSizes; + SmallVector mapping; + std::vector fuseCandidates; + + void apply(ImplicitLocOpBuilder &b, Value pdlV, bool usingForall); +}; + +struct ThreadTileConfig { + SmallVector parallelTileSizes; + SmallVector reductionTileSizes; + SmallVector unrollFactors; + std::vector initOperands; + + void apply(ImplicitLocOpBuilder &b, Value pdlV); +}; + +void processProducerSelectors( + ImplicitLocOpBuilder &b, + const std::vector &producerSelectors, Value fuseInto, + SmallVector &selected, Type producerType = nullptr) { + for (auto selector : producerSelectors) { + auto producer = b.create( + /* 
producer type */ producerType + ? producerType + : transform::OperationType::get(b.getContext(), selector.opName), + /* target */ fuseInto, + /* operand number */ selector.operandNumber); + selected.push_back(producer.getProducer()); + processProducerSelectors(b, selector.producerSelectors, selected.back(), + selected); + } +} + +void tileToForallAndFuseImpl( + ImplicitLocOpBuilder &b, Value toTile, + const SmallVector &tileSizes, + const SmallVector &mapping, + const std::vector &fuseCandidates) { + SmallVector toBeFused; + processProducerSelectors(b, fuseCandidates, toTile, toBeFused); + + auto tileOp = b.create( + /* target */ toTile, + /* staticTileSizes */ tileSizes, + /* ctor tag */ transform::TileSizesSpec(), + /* mapping */ b.getArrayAttr(mapping)); + for (auto &&producerOp : toBeFused) { + b.create( + /* producerOp */ producerOp, + /* containingOp */ tileOp.getForallOp()); + } +} + +void tileToSCFForAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, + const SmallVector &tileSizes, + const SmallVector &mapping) { + auto pdlType = pdl::OperationType::get(b.getContext()); + auto fuseOp = b.create( + /* transformed */ pdlType, + /* loops */ + SmallVector(getNumTiledLoops(tileSizes), pdlType), + /* target */ toTile, + /* tile_sizes */ b.getI64ArrayAttr(tileSizes), + /* tile_interchange */ ArrayAttr()); + for (auto &&[loop, mapTo] : llvm::zip(fuseOp.getLoops(), mapping)) { + Value paramV = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ mapTo); + b.create( + /* target */ loop, + /* name */ getLoopToSIMTAttrName(), + /* param */ paramV); + } +} + +void GridSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { + if (splitFactor) { + auto splitted = b.create( + /* target */ pdlV, + /* splitFactor */ splitFactor, + /* insertSplitDimension */ dimension, + /* innerParallel */ false, + /* useScalingAlgorithm */ false, + /* useAlloc */ false); + b.create( + /* target */ splitted.getSplitLinalgOp(), + /* name */ kGridReduction, + /* 
param */ Value()); + b.create( + /* target */ splitted.getCombiningLinalgOp(), + /* name */ kGridReduction, + /* param */ Value()); + } else { + b.create( + /* target */ pdlV, + /* name */ kGridReduction, + /* param */ Value()); + } +} + +void GridTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV, + bool usingForall) { + if (usingForall) { + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::Blocks dim) -> Attribute { + return gpu::GPUBlockMappingAttr::get(b.getContext(), dim); + })); + tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, fuseCandidates); + } else { + static constexpr std::array mappings{ + getBlockIdXName(), getBlockIdYName(), getBlockIdZName()}; + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::Blocks dim) -> Attribute { + return b.getStringAttr(mappings[static_cast(dim)]); + })); + tileToSCFForAndFuseImpl(b, pdlV, tileSizes, mappingAttrs); + } +} + +void BlockSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { + if (!padDims.empty()) { + auto padOp = b.create( + TypeRange{pdlV.getType(), pdlV.getType()}, pdlV, + /*padding_values=*/b.getArrayAttr(padValues), + /*padding_dimensions=*/ + b.getI64ArrayAttr(padDims), + /*padToMultipleOf=*/ArrayAttr{}, + /*pack_paddings=*/ArrayAttr{}, + /*transpose_paddings=*/ArrayAttr{}, + /*copyBack=*/false); + pdlV = padOp.getPadded(); + } + if (!splitFactors.empty()) { + Value toSplit = pdlV; + for (auto &&[splitFactor, redDim] : llvm::zip(splitFactors, dimensions)) { + auto splitted = b.create( + /* target */ toSplit, + /* splitFactor */ splitFactor, + /* insertSplitDimension */ redDim, + /* innerParallel */ false, + /* useScalingAlgorithm */ false, + /* useAlloc */ false); + b.create( + /* target */ splitted.getInitOrAllocOp(), + /* name */ kBlockReduction, + /* param */ Value()); + b.create( + /* target */ splitted.getCombiningLinalgOp(), + /* name */ kBlockReduction, + /* param */ Value()); + toSplit = splitted.getCombiningLinalgOp(); + } + pdlV = 
toSplit; + } else { + b.create( + /* target */ pdlV, + /* name */ kBlockReduction, + /* param */ Value()); + } + auto func = b.create( + pdlV.getType(), pdlV, + /* isolated_from_above */ true, + /* op_name */ b.getStringAttr(func::FuncOp::getOperationName()), + /* deduplicate */ false); + b.create(func, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + auto forall = b.create( + pdlV.getType(), pdlV, + /* isolated_from_above */ false, + /* op_name */ b.getStringAttr(scf::ForallOp::getOperationName()), + /* deduplicate */ false); + if (!padDims.empty()) { + auto parallelInsertSliceType = transform::OperationType::get( + b.getContext(), tensor::ParallelInsertSliceOp::getOperationName()); + auto parallelInsertSlice = b.create( + parallelInsertSliceType, forall, + tensor::ParallelInsertSliceOp::getOperationName()); + b.create(pdlV.getType(), + parallelInsertSlice); + } + auto emptyTensorType = transform::OperationType::get( + b.getContext(), tensor::EmptyOp::getOperationName()); + auto emptyTensor = b.create( + emptyTensorType, forall, tensor::EmptyOp::getOperationName()); + auto allocTensorType = transform::OperationType::get( + b.getContext(), bufferization::AllocTensorOp::getOperationName()); + auto allocTensor = b.create( + allocTensorType, emptyTensor); + auto memorySpaceAttrName = + bufferization::AllocTensorOp::getMemorySpaceAttrName(OperationName( + bufferization::AllocTensorOp::getOperationName(), b.getContext())); + auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get( + b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); + Value paramV = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ workgroupMemoryAddressSpace); + b.create( + /* target */ allocTensor, + /* name */ memorySpaceAttrName, + /* param */ paramV); +} + +void BlockTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV, + bool usingForall) { + if (usingForall) { + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::Threads 
dim) -> Attribute { + return gpu::GPUThreadMappingAttr::get(b.getContext(), dim); + })); + tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, fuseCandidates); + } else { + static constexpr std::array mappings{ + getThreadIdXName(), getThreadIdYName(), getThreadIdZName()}; + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::Threads dim) -> Attribute { + return b.getStringAttr(mappings[static_cast(dim)]); + })); + tileToSCFForAndFuseImpl(b, pdlV, tileSizes, mappingAttrs); + } +} + +void ThreadTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { + auto pdlType = pdl::OperationType::get(b.getContext()); + auto numTiledParallelLoops = getNumTiledLoops(parallelTileSizes); + SmallVector loops; + if (numTiledParallelLoops > 0) { + auto fuseOp = b.create( + /* transformed */ pdlType, + /* loops */ + SmallVector(getNumTiledLoops(parallelTileSizes), pdlType), + /* target */ pdlV, + /* tile_sizes */ b.getI64ArrayAttr(parallelTileSizes), + /* tile_interchange */ ArrayAttr()); + loops = fuseOp.getLoops(); + pdlV = fuseOp.getTransformed(); + } + + auto tileOp = b.create( + /* target */ pdlV, + /* tillSizes */ reductionTileSizes); + loops.push_back(tileOp.getLoops()[0]); + for (auto &&[loop, factor] : llvm::reverse(llvm::zip(loops, unrollFactors))) { + b.create(loop, factor); + } +} + +//----------------------------------------------------------------------------// +// codegen strategies +//----------------------------------------------------------------------------// + +bool isReductionOp(linalg::GenericOp genericOp) { + if (genericOp.getNumReductionLoops() != 1) + return false; + + if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap affineMap) { + return affineMap.isProjectedPermutation(/* allowZeroInResults */ false); + })) + return false; + + return true; +} + +bool isGridReductionOp(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + + // early return for manual tag + if 
(genericOp->hasAttr(kGridReduction)) + return true; + + // top level generic op in function + if (genericOp->getParentOfType()) + return true; + + return false; +} + +bool isBlockReductionOp(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + + // early return for manual tag + if (genericOp->hasAttr(kBlockReduction)) + return true; + + // nested in op which is mapped to GPU blocks + if (isMappedToGPUBlocks(genericOp->getParentOp())) + return true; + + return false; +} + +bool isThreadReductionOp(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + + // early return for manual tag + if (genericOp->hasAttr(kThreadReduction)) + return true; + + // nested in op which is mapped to GPU threads + if (isMappedToGPUThreads(genericOp->getParentOp())) + return true; + + return false; +} + +std::optional getGridSplitConfig(linalg::GenericOp genericOp, + int64_t splitFactor) { + if (!isGridReductionOp(genericOp)) + return std::nullopt; + + auto redDim = *getReductionDim(genericOp); + auto staticLoopRanges = genericOp.getStaticLoopRanges(); + if (ShapedType::isDynamic(staticLoopRanges[redDim]) || + staticLoopRanges[redDim] % splitFactor != 0 || + staticLoopRanges[redDim] <= 1024) + return std::nullopt; + + return GridSplitConfig{splitFactor, redDim ? 
redDim - 1 : redDim}; +} + +std::optional getGridTileConfig(linalg::GenericOp genericOp, + int64_t warpSize, + int64_t blockSize) { + if (!isGridReductionOp(genericOp)) + return std::nullopt; + + int64_t numLoops = genericOp.getNumLoops(); + SmallVector tileSizes(numLoops, 1); + auto loopSizes = + cast(genericOp.getOperation()).computeStaticLoopSizes(); + + for (auto &&affineMap : genericOp.getIndexingMapsArray()) { + if (affineMap.isPermutation()) { + auto dim = affineMap.getDimPosition(numLoops - 1); + if (loopSizes[dim] > warpSize) { // TODO: padding + tileSizes[dim] *= warpSize; + break; + } + } + } + + auto redDim = getReductionDim(genericOp).value(); + tileSizes[redDim] = 0; + + std::vector fuseCandidates; + for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { + ProducerSelector::detectFillOperand(opOperand, fuseCandidates); + } + + auto numTiledLoops = getNumTiledLoops(tileSizes); + if (numTiledLoops >= 1 && numTiledLoops <= 3) { + SmallVector mapping(numLoops, -1); + int64_t dimMapping = static_cast(gpu::Blocks::DimX); + for (auto &&affineMap : genericOp.getIndexingMapsArray()) { + if (affineMap.isPermutation()) { + for (int64_t i = numLoops - 1; i >= 0; i--) { + auto dim = affineMap.getDimPosition(i); + if (tileSizes[dim] > 0) { + mapping[dim] = dimMapping++; + } + } + break; + } + } + mapping.erase(std::remove(mapping.begin(), mapping.end(), -1), + mapping.end()); + if (mapping.size() != numTiledLoops) + return std::nullopt; + + return GridTileConfig{ + tileSizes, + llvm::to_vector(llvm::map_range( + mapping, [](int64_t i) { return static_cast(i); })), + fuseCandidates}; + } + return std::nullopt; +} + +std::optional getBlockSplitConfig(linalg::GenericOp genericOp, + int64_t splitFactor, + int64_t warpSize) { + if (!isBlockReductionOp(genericOp)) + return std::nullopt; + + SmallVector padDims = getDynamicDims(genericOp); + SmallVector padValues; + + SmallVector splitFactors; + SmallVector dimensions; + auto redDim = *getReductionDim(genericOp); + 
auto staticLoopRanges = genericOp.getStaticLoopRanges(); + if (ShapedType::isDynamic(staticLoopRanges[redDim])) + return std::nullopt; + + if (auto redPos = getOperandReductionDim(*genericOp.getDpsInputOperand(0))) { + if (redPos.value() == genericOp.getNumLoops() - 1) { + auto newSplitFactor = splitFactor * 2; + while (staticLoopRanges[redDim] % newSplitFactor == 0 && + newSplitFactor <= splitFactor * warpSize) { + newSplitFactor *= 2; + } + splitFactor = newSplitFactor / 2; + } + } + + if (staticLoopRanges[redDim] < splitFactor) { + splitFactor = staticLoopRanges[redDim]; + } else { + if (staticLoopRanges[redDim] % splitFactor != 0) + return std::nullopt; + + splitFactors.push_back(splitFactor); + dimensions.push_back(redDim ? redDim - 1 : redDim); + } + + mlir::Builder b(genericOp.getContext()); + for (auto &&operand : genericOp->getOperands()) { + if (auto shapedType = llvm::dyn_cast(operand.getType())) { + padValues.push_back(b.getZeroAttr(shapedType.getElementType())); + } else { + return std::nullopt; + } + } + + for (; splitFactor > 2; splitFactor >>= 1) { + splitFactors.push_back(splitFactor / 2); + dimensions.push_back(redDim ? 
redDim - 1 : redDim); + } + + return BlockSplitConfig{splitFactors, dimensions, padDims, padValues}; +} + +std::optional getBlockTileConfig(linalg::GenericOp genericOp, + int64_t warpSize, + int64_t blockSize) { + if (!isBlockReductionOp(genericOp)) + return std::nullopt; + + int64_t numLoops = genericOp.getNumLoops(); + SmallVector tileSizes(numLoops, 0); + auto loopSizes = + cast(genericOp.getOperation()).computeStaticLoopSizes(); + + int64_t remainBlockSize = blockSize; + auto redDim = getReductionDim(genericOp).value(); + for (int64_t idx = 0; idx < numLoops && remainBlockSize > 1; ++idx) { + if (idx == redDim) + continue; + int64_t curLoopSize2 = nextPowerOf2(loopSizes[idx]); + int64_t curBlockSize = std::min(curLoopSize2, remainBlockSize); + tileSizes[idx] = curLoopSize2 / curBlockSize; + remainBlockSize /= curBlockSize; + } + + if (remainBlockSize == blockSize) { + tileSizes[redDim] = loopSizes[redDim]; + } + + std::vector fuseCandidates; + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { + ProducerSelector::detectPadOperand(opOperand, fuseCandidates); + } + for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { + ProducerSelector::detectFillOperand(opOperand, fuseCandidates); + } + + auto numTiledLoops = getNumTiledLoops(tileSizes); + if (numTiledLoops >= 1 && numTiledLoops <= 3) { + SmallVector mapping(numLoops, -1); + int64_t dimMapping = static_cast(gpu::Threads::DimX); + for (auto &&affineMap : genericOp.getIndexingMapsArray()) { + if (affineMap.isPermutation()) { + for (int64_t i = numLoops - 1; i >= 0; i--) { + auto dim = affineMap.getDimPosition(i); + if (tileSizes[dim] > 0) { + mapping[dim] = dimMapping++; + } + } + break; + } + } + mapping.erase(std::remove(mapping.begin(), mapping.end(), -1), + mapping.end()); + if (mapping.size() != numTiledLoops) + return std::nullopt; + + return BlockTileConfig{ + tileSizes, + llvm::to_vector(llvm::map_range( + mapping, [](int64_t i) { return static_cast(i); })), + fuseCandidates}; + } + 
return std::nullopt; +} + +std::optional +getThreadTileConfig(linalg::GenericOp genericOp) { + if (!isThreadReductionOp(genericOp)) + return std::nullopt; + + int64_t numLoops = genericOp.getNumLoops(); + SmallVector parallelTileSizes(numLoops, 1); + SmallVector reductionTileSizes(numLoops, 0); + auto reductionDim = *getReductionDim(genericOp); + + parallelTileSizes[reductionDim] = 0; + reductionTileSizes[reductionDim] = 1; + + SmallVector unrollFactors = + cast(genericOp.getOperation()).computeStaticLoopSizes(); + + std::vector initOperands; + for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { + ProducerSelector::detectFillOperand(opOperand, initOperands); + } + + return ThreadTileConfig{parallelTileSizes, reductionTileSizes, unrollFactors, + initOperands}; +} + +//----------------------------------------------------------------------------// +// transform insertion impl +//----------------------------------------------------------------------------// + +void createGPUSplitGridReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t splitFactor) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + return getGridSplitConfig(genericOp, splitFactor).has_value(); + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto splitConfig = + getGridSplitConfig(llvm::cast(op), splitFactor) + .value(); + splitConfig.apply(b, pdlV); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +void createGPUTileGridReductionTransformImpl( + OpPassManager &pm, const std::string &anchor, const std::string &prefix, + int64_t warpSize, int64_t blockSize, bool usingForall) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) 
{ + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + return getGridTileConfig(genericOp, warpSize, blockSize).has_value(); + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto tileConfig = getGridTileConfig(llvm::cast(op), + warpSize, blockSize) + .value(); + tileConfig.apply(b, pdlV, usingForall); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +void createGPUSplitBlockReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t splitFactor, + int64_t warpSize) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + return getBlockSplitConfig(genericOp, splitFactor, warpSize).has_value(); + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto splitConfig = getBlockSplitConfig(llvm::cast(op), + splitFactor, warpSize) + .value(); + splitConfig.apply(b, pdlV); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +void createGPUTileBlockReductionTransformImpl( + OpPassManager &pm, const std::string &anchor, const std::string &prefix, + int64_t warpSize, int64_t blockSize, bool usingForall) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + return getBlockTileConfig(genericOp, warpSize, blockSize).has_value(); + } else if (auto copyOp = llvm::dyn_cast_or_null(op)) { + return copyOp.getNumLoops() == 1; + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + auto tileConfig = getBlockTileConfig(llvm::cast(op), + warpSize, blockSize) + .value(); + 
tileConfig.apply(b, pdlV, usingForall); + } else if (auto copyOp = llvm::dyn_cast_or_null(op)) { + auto tileOp = b.create( + /* target */ pdlV, + /* staticTileSizes */ SmallVector(1, blockSize), + /* ctor tag */ transform::NumThreadsSpec(), + /* mapping */ + b.getArrayAttr(gpu::GPULinearIdMappingAttr::get( + b.getContext(), gpu::LinearId::DimX))); + } + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +void createGPUTileThreadReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + return getThreadTileConfig(genericOp).has_value(); + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto tileConfig = + getThreadTileConfig(llvm::cast(op)).value(); + tileConfig.apply(b, pdlV); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} +} // namespace + +void mlir::createGPUSplitGridReductionTransform( + OpPassManager &pm, const GPUSplitGridReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUSplitGridReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.splitFactor); +} + +void mlir::createGPUTileGridReductionTransform( + OpPassManager &pm, const GPUTileGridReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileGridReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.warpSize, options.blockSize, + options.usingForall); +} + +void mlir::createGPUSplitBlockReductionTransform( + OpPassManager &pm, const GPUSplitBlockReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUSplitBlockReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.splitFactor, options.warpSize); +} + +void 
mlir::createGPUTileBlockReductionTransform( + OpPassManager &pm, const GPUTileBlockReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileBlockReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.warpSize, options.blockSize, + options.usingForall); +} + +void mlir::createGPUTileThreadReductionTransform( + OpPassManager &pm, const GPUTileThreadReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileThreadReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix); +} diff --git a/compiler/lib/Pipelines/HloOpt.cpp b/compiler/lib/Pipelines/HloOpt.cpp index b80e815de..7aa45267e 100644 --- a/compiler/lib/Pipelines/HloOpt.cpp +++ b/compiler/lib/Pipelines/HloOpt.cpp @@ -21,6 +21,7 @@ #include "byteir/Pipelines/Common/Utils.h" #include "byteir/Transforms/CanonicalizeExt.h" #include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/Passes.h" @@ -50,6 +51,7 @@ void addGenericHloFusionPatterns(OpPassManager &pm, const std::string &entry, pm.addPass(createCSEPass()); pm.addNestedPass(createFlattenTuplePass()); + pm.addNestedPass(createReductionFusionPass()); // Element fusion (always last?) 
// Note: if outlineSingleElemwiseOp is set, element fusion must be the last // pass, since it will cluster every elemenwise op which is not fused yet into @@ -106,6 +108,7 @@ void createHloOptPipelineImpl(OpPassManager &pm, const std::string &entryFunc, pm.addPass(createCSEPass()); pm.addPass(createCanonicalizeExtPass()); pm.addPass(createSymbolDCEPass()); + pm.addPass(func::createDuplicateFunctionEliminationPass()); } } // namespace diff --git a/compiler/lib/Pipelines/LinalgMemrefOpt.cpp b/compiler/lib/Pipelines/LinalgMemrefOpt.cpp index f6a4d55b6..6b11d2503 100644 --- a/compiler/lib/Pipelines/LinalgMemrefOpt.cpp +++ b/compiler/lib/Pipelines/LinalgMemrefOpt.cpp @@ -33,7 +33,9 @@ void addGenericLinalgMemrefOptPasses(OpPassManager &pm) { pm.addPass(createMemrefCopyToLinalgPass( getAttrPlaceholderName( byre::ByreDialect::getEntryPointFunctionAttrName()), - getByteIRElementwiseFusionAttrName().str())); + getByteIRElementwiseFusionAttrName().str(), true)); + pm.addPass(createMemrefCopyToLinalgPass( + getByteIRReductionFusionAttrName().str(), "", false)); } void createLinalgMemrefOptPipelineImpl(OpPassManager &pm, diff --git a/compiler/lib/Pipelines/LinalgTensorOpt.cpp b/compiler/lib/Pipelines/LinalgTensorOpt.cpp index c220a382a..88c1ab2b9 100644 --- a/compiler/lib/Pipelines/LinalgTensorOpt.cpp +++ b/compiler/lib/Pipelines/LinalgTensorOpt.cpp @@ -17,58 +17,39 @@ #include "byteir/Pipelines/LinalgTensorOpt.h" #include "byteir/Pipelines/GPU/ElementwiseCodegen.h" +#include "byteir/Pipelines/GPU/ReductionCodegen.h" #include "byteir/Pipelines/Host/Codegen.h" #include "byteir/Conversion/ToLinalg/ToLinalg.h" #include "byteir/Dialect/Linalg/Passes.h" +#include "byteir/Dialect/Tensor/Passes.h" #include "byteir/Dialect/Transform/Transforms/TransformDialectInterpreter.h" +#include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Pipelines/Common/Utils.h" +#include 
"byteir/Transforms/AnchoredPipeline.h" #include "byteir/Transforms/CanonicalizeExt.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Transforms/Passes.h" using namespace mlir; namespace { -void collectBroadcastOperands( - mlir::Operation *op, - mlir::DenseMap &collection) { - - auto tensorSlice = dyn_cast(op); - if (!tensorSlice) { - return; - } - - for (Value res : op->getResults()) { - bool isBroadcast = false; - for (auto &&use : res.getUses()) { - if (auto genericOp = dyn_cast(use.getOwner())) { - auto affineMap = - genericOp.getIndexingMapsArray()[use.getOperandNumber()]; - if (!affineMap.isPermutation() && - affineMap.isProjectedPermutation(/*allowZeroInResults*/ true)) { - isBroadcast = true; - } - } - } - if (isBroadcast) { - collection.insert(std::make_pair(res, std::make_pair(Attribute(), true))); - } - } -} - -void addGenericLinalgElementwisePasses(OpPassManager &pm) { +void addGenericLinalgPasses(OpPassManager &pm) { pm.addNestedPass( createHloFusionToLinalgPass(getByteIRElementwiseFusionAttrName())); + pm.addNestedPass( + createHloFusionToLinalgPass(getByteIRReductionFusionAttrName())); pm.addNestedPass(createUnrealizedCastToLinalgPass()); pm.addPass(createLinalgElementwiseFusionExtPass( /*enableSharedInput*/ true, /*enableDiffShapes*/ false)); pm.addPass(createCSEPass()); - { + { // elementwise codegen + auto elementwiseAnchor = getByteIRElementwiseFusionAttrName().str(); GPUTileElementwiseOptions options; - options.funcAnchor = getByteIRElementwiseFusionAttrName().str(); + options.funcAnchor = elementwiseAnchor; // set to 1 for fully fusion & unroll, and all tiled loops will be coalesced // and mapping to LinearIdx.x in later pipeline // FIXME: set to real blockSize and mapping tiled loops to the corresponding @@ -77,13 +58,98 @@ void addGenericLinalgElementwisePasses(OpPassManager &pm) { options.warpSize = 32; createGPUTileElementwiseTransform(pm, options); 
pm.addPass(createTransformDialectInterpreter(true)); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + anchoredPM.addPass(createLinalgElementwiseFusionExtPass( + /*enableSharedInput*/ true, /*enableDiffShapes*/ false)); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createCanonicalizerPass()); + pm.addNestedPass( + createAnchoredPipelinePass(elementwiseAnchor, anchoredPM)); + } + } + { // reduction codegen + auto reductionAnchor = getByteIRReductionFusionAttrName().str(); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass( + createLinalgCollapseLoops(utils::IteratorType::reduction)); + anchoredPM.addPass( + createLinalgCollapseLoops(utils::IteratorType::parallel)); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } + + GPUSplitGridReductionOptions splitGridRedOptions; + splitGridRedOptions.funcAnchor = reductionAnchor; + createGPUSplitGridReductionTransform(pm, splitGridRedOptions); + pm.addPass(createTransformDialectInterpreter(true)); pm.addPass(createCanonicalizerPass()); + + GPUTileGridReductionOptions tileGridRedOptions; + tileGridRedOptions.funcAnchor = reductionAnchor; + tileGridRedOptions.blockSize = 512; + createGPUTileGridReductionTransform(pm, tileGridRedOptions); + pm.addPass(createTransformDialectInterpreter(true)); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } + + GPUSplitBlockReductionOptions splitBlockRedOptions; + splitBlockRedOptions.funcAnchor = reductionAnchor; + splitBlockRedOptions.splitFactor = 16; + createGPUSplitBlockReductionTransform(pm, splitBlockRedOptions); + 
pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(createCanonicalizerPass()); + + GPUTileBlockReductionOptions tileBlockRedOptions; + tileBlockRedOptions.funcAnchor = reductionAnchor; + tileBlockRedOptions.blockSize = 512; + createGPUTileBlockReductionTransform(pm, tileBlockRedOptions); + pm.addPass(createTransformDialectInterpreter(true)); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } + + GPUTileThreadReductionOptions tileThreadRedOptions; + tileThreadRedOptions.funcAnchor = reductionAnchor; + createGPUTileThreadReductionTransform(pm, tileThreadRedOptions); + pm.addPass(createTransformDialectInterpreter(true)); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } + + pm.addPass(createDetensorizeTransformInsertionPass(reductionAnchor)); + pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(createCanonicalizeExtPass()); + pm.addPass(createRewriteInDPSTransformInsertionPass(reductionAnchor)); + pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(createCanonicalizerPass()); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createTensorPadSpecializationPass()); + anchoredPM.addPass(bufferization::createEmptyTensorEliminationPass()); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } } - pm.addPass(createLinalgFoldUnitExtentDimsPass()); - pm.addPass(createLinalgElementwiseFusionExtPass( - /*enableSharedInput*/ true, /*enableDiffShapes*/ false)); - 
pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); } void addCPULinalgOptPasses(OpPassManager &pm) { @@ -110,7 +176,7 @@ void createLinalgTensorOptPipelineImpl(OpPassManager &pm, if (target == "CPU") { addCPULinalgOptPasses(pm); } else { - addGenericLinalgElementwisePasses(pm); + addGenericLinalgPasses(pm); } } } // namespace diff --git a/compiler/lib/Transforms/Bufferize.cpp b/compiler/lib/Transforms/Bufferize.cpp index c93e964ef..0d0d14c4a 100644 --- a/compiler/lib/Transforms/Bufferize.cpp +++ b/compiler/lib/Transforms/Bufferize.cpp @@ -17,7 +17,6 @@ #include "byteir/Transforms/Bufferize.h" -#include "./PassDetail.h" #include "byteir/Dialect/Ace/AceDialect.h" #include "byteir/Dialect/Byre/ByreDialect.h" #include "byteir/Dialect/Byre/Transforms/BufferizableOpInterfaceImpl.h" @@ -40,6 +39,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" @@ -62,6 +62,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "./PassDetail.h" + using namespace mlir; using namespace mlir::bufferization; @@ -93,6 +95,28 @@ struct OneShotBufferizePass vector::registerBufferizableOpInterfaceExternalModels(registry); } + static bool isGPUSharedMem(MemRefType type) { + if (auto memorySpace = llvm::dyn_cast_or_null( + type.getMemorySpace())) { + if (memorySpace.getValue() == + gpu::GPUDialect::getWorkgroupAddressSpace()) { + return true; + } + } + return false; + } + + template + static auto createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, size_t bufferAlignment) { + if (bufferAlignment != 0) + return b + .create(loc, type, dynShape, + b.getI64IntegerAttr(bufferAlignment)) + .getResult(); + return b.create(loc, type, 
dynShape).getResult(); + } + void runOnOperation() override { bufferization::OneShotBufferizationOptions opts; opts.allowReturnAllocs = true; @@ -101,6 +125,29 @@ struct OneShotBufferizePass bufferization::LayoutMapOption::IdentityLayoutMap); opts.createDeallocs = false; opts.bufferAlignment = 0; + opts.allocationFn = [](OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, + unsigned int bufferAlignment) -> FailureOr { + if (isGPUSharedMem(type)) { + return createAlloc(b, loc, type, dynShape, + bufferAlignment); + } + return createAlloc(b, loc, type, dynShape, + bufferAlignment); + }; + opts.deallocationFn = [](OpBuilder &b, Location loc, + Value allocatedBuffer) -> LogicalResult { + if (auto bufferType = + llvm::dyn_cast_or_null(allocatedBuffer.getType())) { + if (isGPUSharedMem(bufferType)) { + return success(); + } + } + + // Default buffer deallocation via DeallocOp. + b.create(loc, allocatedBuffer); + return success(); + }; // deny some corner cases opts.opFilter.denyOperation([&](Operation *op) { @@ -272,6 +319,180 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, return success(); } } // namespace CallOpBufferizableOpInterfacePatch + +// ------------------------------------------------------------------------ // +// Patch of TensorInsertOp +// ------------------------------------------------------------------------ // +namespace TensorInsertPatch { +bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) { + assert(isa(op) && + "expected that op implements DestinationStyleOpInterface"); + + if (opOperand.getOperandNumber() == 1 && + opOperand.get().getType().cast().getRank() == 0) { + return false; + } + + return true; +} + +} // namespace TensorInsertPatch + +template static bool overwriteEntireTensor(OpTy insertSliceOp) { + RankedTensorType destType = insertSliceOp.getDestType(); + // Dest is not read if it is entirely overwritten. 
E.g.: + // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> + bool allOffsetsZero = + llvm::all_of(insertSliceOp.getMixedOffsets(), + [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }); + bool sizesMatchDestSizes = llvm::all_of( + llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) { + return getConstantIntValue(it.value()) == + destType.getDimSize(it.index()); + }); + bool allStridesOne = + llvm::all_of(insertSliceOp.getMixedStrides(), + [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }); + return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne); +} + +/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. +/// equivalent operand / result and same offset/sizes/strides specification). +template +static bool areEquivalentSlices(const AnalysisState &state, + tensor::ExtractSliceOp extractSliceOp, + OpTy insertSliceOp) { + if (!extractSliceOp || !insertSliceOp) + return false; + if (extractSliceOp != insertSliceOp && + !state.areEquivalentBufferizedValues(extractSliceOp.getSource(), + insertSliceOp.getDest())) + return false; + if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp, + isEqualConstantIntOrValue)) + return false; + return true; +} + +/// Return true if `value` is originating from an ExtractSliceOp that matches +/// the given InsertSliceOp. +template +static bool matchesInsertDestination(const AnalysisState &state, Value value, + OpTy insertSliceOp) { + // Look for matching slices. 
+ auto matchesSlice = [&](Value val) { + if (auto extractSliceOp = val.getDefiningOp()) + if (areEquivalentSlices(state, extractSliceOp, insertSliceOp)) + return true; + return false; + }; + return static_cast(llvm::all_of( + state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice)); +} + +template +static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead, + OpOperand *uConflictingWrite, + const AnalysisState &state) { + Operation *readingOp = uRead->getOwner(); + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If + // uRead is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast(readingOp)) { + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + + // TODO: Use insertSliceOp.getDestOpOperand etc. when available. + if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && + matchesInsertDestination(state, uConflictingWrite->get(), + insertSliceOp)) + // Case 1: The main insight is that InsertSliceOp reads only part of + // the destination tensor. The overwritten area is not read. If + // uConflictingWrite writes into exactly the memory location that is + // being read by uRead, this is not a conflict. + // + // In the above example: + // uRead = OpOperand 1 (%t) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%0) of linalg.fill + // + // The read of %t does not conflict with the write of the FillOp + // (same aliases!) because the area that the FillOp operates on is + // exactly the one that is *not* read via %t. 
+ return true; + + if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && + uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + (overwriteEntireTensor(insertSliceOp) || + matchesInsertDestination(state, uRead->get(), insertSliceOp))) + // Case 2: The read of the source tensor and the write to the dest + // tensor via an InsertSliceOp is not a conflict if the read is + // reading exactly that part of an equivalent tensor that the + // InsertSliceOp is writing. + // + // In the above example: + // uRead = OpOperand 0 (%1) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + return true; + } + + // If uConflictingWrite is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast(conflictingWritingOp)) + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + // %3 = vector.transfer_read %1, %cst + // + // In the above example: + // uRead = OpOperand 0 (%1) of vector.transfer_read + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + // definition = %1 + // + // This is not a conflict because the InsertSliceOp overwrites the + // memory segment of %1 with the exact same data. (Effectively, there + // is no memory write here.) 
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + state.areEquivalentBufferizedValues(uRead->get(), + insertSliceOp.getSource()) && + matchesInsertDestination(state, insertSliceOp.getSource(), + insertSliceOp)) + return true; + + return false; +} + +// ------------------------------------------------------------------------ // +// Patch of TensorParallelInsertSlice +// ------------------------------------------------------------------------ // +namespace TensorParallelInsertSlicePatch { +bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) { + auto insertSliceOp = cast(op); + + // The source is always read. + if (&opOperand == &op->getOpOperand(0) /*src*/) + return true; + + // For the destination, it depends... + assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest"); + + return overwriteEntireTensor(insertSliceOp); +} +bool isNotConflicting(Operation *op, OpOperand *uRead, + OpOperand *uConflictingWrite, + const AnalysisState &state) { + return isNotConflictingInsertSliceLikeOp( + op, uRead, uConflictingWrite, state); +} +} // namespace TensorParallelInsertSlicePatch } // namespace // TODO: removed this once upstrem fixed it @@ -279,6 +500,21 @@ RegisterOpInterfaceOverride( /*Op=*/func::CallOp, /*Interface=*/BufferizableOpInterface, /*InterfaceMethod=*/bufferize, /*Impl=*/&CallOpBufferizableOpInterfacePatch::bufferize); +RegisterOpInterfaceOverride( + /*Op=*/tensor::InsertOp, /*Interface=*/BufferizableOpInterface, + /*InterfaceMethod=*/bufferizesToMemoryRead, + /*Impl=*/ + &TensorInsertPatch::bufferizesToMemoryRead); +RegisterOpInterfaceOverride( + /*Op=*/tensor::ParallelInsertSliceOp, /*Interface=*/BufferizableOpInterface, + /*InterfaceMethod=*/bufferizesToMemoryRead, + /*Impl=*/ + &TensorParallelInsertSlicePatch::bufferizesToMemoryRead); +RegisterOpInterfaceOverride( + /*Op=*/tensor::ParallelInsertSliceOp, /*Interface=*/BufferizableOpInterface, + 
/*InterfaceMethod=*/isNotConflicting, + /*Impl=*/ + &TensorParallelInsertSlicePatch::isNotConflicting); std::unique_ptr> byteir::createOneShotBufferizePass() { diff --git a/compiler/lib/Transforms/MemoryPlanning.cpp b/compiler/lib/Transforms/MemoryPlanning.cpp index b00ae8e8b..70e9cce2c 100644 --- a/compiler/lib/Transforms/MemoryPlanning.cpp +++ b/compiler/lib/Transforms/MemoryPlanning.cpp @@ -220,10 +220,11 @@ template class SortedPackingStrategy { : windowSize(windowSize), alignment(alignment), compare(compare) {} /// Optimize the buffer allocations. - void optimze(const mlir::bufferization::BufferPlacementAllocs &allocs, - const UserangeAnalysis &userangeAnalysis, - std::vector &packedBuffers, - std::function isValidAllocation) { + void optimze( + const mlir::bufferization::BufferPlacementAllocs::AllocEntryList &allocs, + const UserangeAnalysis &userangeAnalysis, + std::vector &packedBuffers, + std::function isValidAllocation) { AllocInfoList allocInfos; allocInfos.reserve(std::distance(allocs.begin(), allocs.end())); @@ -344,7 +345,7 @@ template class SortedPackingStrategy { /// maximal userange. size_t computeAllocationInfos( AllocInfoList &allocInfos, const UserangeAnalysis &userangeAnalysis, - const mlir::bufferization::BufferPlacementAllocs &allocs, + const mlir::bufferization::BufferPlacementAllocs::AllocEntryList &allocs, std::function isValidAllocation) { // Create allocInformations and store them in allocInfos. size_t maxUserangeId = 0; @@ -405,13 +406,15 @@ template class SortedPackingStrategy { /// argument. 
template class BufferPacking : bufferization::BufferPlacementTransformationBase { + static constexpr bool is_alloca = std::is_same_v; + public: template BufferPacking(Operation *op, StrategyT strategy, std::function couldReuseAllocation) : BufferPlacementTransformationBase(op), liveness(op), - userangeAnalysis(op, &liveness, initAllocs(op), aliases), - dominators(op) { + allocs(initAllocs(op)), + userangeAnalysis(op, &liveness, allocs, aliases), dominators(op) { std::vector packedBuffers; strategy.optimze(allocs, userangeAnalysis, packedBuffers, couldReuseAllocation); @@ -434,6 +437,7 @@ class BufferPacking : bufferization::BufferPlacementTransformationBase { private: byteir::Liveness liveness; + bufferization::BufferPlacementAllocs::AllocEntryList allocs; UserangeAnalysis userangeAnalysis; /// The current dominance info. DominanceInfo dominators; @@ -451,13 +455,18 @@ class BufferPacking : bufferization::BufferPlacementTransformationBase { dominators); } - const bufferization::BufferPlacementAllocs &initAllocs(Operation *op) { + bufferization::BufferPlacementAllocs::AllocEntryList + initAllocs(Operation *op) { if constexpr (std::is_same_v) { + bufferization::BufferPlacementAllocs::AllocEntryList ret; op->walk([&](memref::AllocaOp alloca) { - allocs.registerAlloc({alloca.getResult(), nullptr}); + ret.emplace_back(alloca.getResult(), nullptr); }); + return ret; + } else { + auto &&baseAllocs = BufferPlacementTransformationBase::allocs; + return {baseAllocs.begin(), baseAllocs.end()}; } - return allocs; } void createBufferAndViews(const PackedBuffer &packedBuffer) { @@ -505,7 +514,7 @@ class BufferPacking : bufferization::BufferPlacementTransformationBase { }; template -inline void doBufferPacking(mlir::func::FuncOp func, size_t alignment, +inline void doBufferPacking(FunctionOpInterface func, size_t alignment, std::function couldReuseAllocation) { SortedPackingStrategy strategy( 0, // windowSize @@ -517,10 +526,12 @@ inline void doBufferPacking(mlir::func::FuncOp 
func, size_t alignment, struct MemoryPlanningPass : public MemoryPlanningBase { MemoryPlanningPass() = default; - MemoryPlanningPass(size_t alignment, + MemoryPlanningPass(size_t alignment, bool alloca, size_t memSpace, std::function couldReuseAllocation) : MemoryPlanningBase() { this->alignment = alignment; + this->alloca = alloca; + this->memSpace = memSpace; this->couldReuseAllocation = couldReuseAllocation; } @@ -559,11 +570,15 @@ struct MemoryPlanningPass : public MemoryPlanningBase { }; } // namespace -std::unique_ptr> mlir::createMemoryPlanningPass() { +std::unique_ptr> +mlir::createMemoryPlanningPass() { return std::make_unique(); } -std::unique_ptr> mlir::createMemoryPlanningPass( - size_t alignment, std::function couldReuseAllocation) { - return std::make_unique(alignment, couldReuseAllocation); +std::unique_ptr> +mlir::createMemoryPlanningPass( + size_t alignment, bool alloca, size_t memSpace, + std::function couldReuseAllocation) { + return std::make_unique(alignment, alloca, memSpace, + couldReuseAllocation); } diff --git a/compiler/lib/Utils/Utils.cpp b/compiler/lib/Utils/Utils.cpp index aae017171..0dc02bb91 100644 --- a/compiler/lib/Utils/Utils.cpp +++ b/compiler/lib/Utils/Utils.cpp @@ -452,12 +452,13 @@ Value mlir::getSlice(OpBuilder &b, Location loc, Value source, OpFoldResult mlir::canonicalizeOpFoldResult(OpFoldResult ofr, bool enableFold) { if (auto val = ofr.dyn_cast()) { - SmallVector foldResults; if (enableFold) { - OpBuilder builder(val.getContext()); - Operation *op = val.getDefiningOp(); - if (op && !failed(builder.tryFold(val.getDefiningOp(), foldResults))) { - val = foldResults[0]; + if (auto opResult = llvm::dyn_cast(val)) { + OpBuilder builder(opResult.getOwner()); + SmallVector foldResults; + if (!failed(builder.tryFold(opResult.getOwner(), foldResults))) { + val = foldResults[opResult.getResultNumber()]; + } } } return getAsOpFoldResult(val); diff --git a/compiler/numerical/hlo/canonicalize_ext.mlir 
b/compiler/numerical/hlo/canonicalize_ext.mlir index 924cfa162..dd961caee 100644 --- a/compiler/numerical/hlo/canonicalize_ext.mlir +++ b/compiler/numerical/hlo/canonicalize_ext.mlir @@ -273,3 +273,40 @@ func.func @fold_large_constant_reverse_float_1(%arg0: tensor<1x3x3x128x64xf32>) // CHECK-NEXT: mhlo.constant // CHECK-NEXT: mhlo.constant // CHECK-NOT: mhlo.reverse + +func.func @replace_gather_with_input_0() -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = mhlo.constant dense<1.000000e+00> : tensor<64x128xf16> + %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64xi64> + %2 = "mhlo.gather"(%0, %1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 128]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<64xi64>) -> tensor<64x128xf16> + %3 = mhlo.reshape %2 : (tensor<64x128xf16>) -> tensor<1x64x128xf16> + %4 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %3, %4 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_0 +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: return + +func.func @replace_gather_with_input_1(%arg0: tensor<64x128xf16>) -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64xi64> + %1 = "mhlo.gather"(%arg0, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 128]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<64xi64>) -> tensor<64x128xf16> + %2 = mhlo.reshape %1 : (tensor<64x128xf16>) -> tensor<1x64x128xf16> + %3 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %2, %3 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_1 +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: mhlo.broadcast_in_dim +// 
CHECK-NEXT: return + +func.func @replace_gather_with_input_2(%arg0: tensor<64x128xf16>) -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<128xi64> + %1 = "mhlo.gather"(%arg0, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[64, 1]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<128xi64>) -> tensor<64x128xf16> + %2 = mhlo.reshape %1 : (tensor<64x128xf16>) -> tensor<1x64x128xf16> + %3 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %2, %3 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_2 +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: mhlo.broadcast_in_dim +// CHECK-NEXT: return diff --git a/compiler/numerical/hlo/hlo_move_down.mlir b/compiler/numerical/hlo/hlo_move_down.mlir index 0676ae6f5..0afdd1ce3 100644 --- a/compiler/numerical/hlo/hlo_move_down.mlir +++ b/compiler/numerical/hlo/hlo_move_down.mlir @@ -179,3 +179,25 @@ func.func @broadcast_reshape_dot_with_concat_and_add(%arg0 : tensor<1x64xf16>, % // CHECK-NEXT: mhlo.broadcast_in_dim // CHECK-NEXT: return +func.func @transpose_move_down_binary_case0(%arg0 : tensor<2x128x4x16xf32>) -> tensor<2x4x16x128xf32> { + %0 = mhlo.constant dense<"0xtensor<2x4x16x128xf32> + %1 = "mhlo.transpose"(%arg0) {permutation = dense<[0,2,3,1]> : tensor<4xi64>} : (tensor<2x128x4x16xf32>) -> tensor<2x4x16x128xf32> + %2 = mhlo.add %1, %0 : tensor<2x4x16x128xf32> + return %2 : tensor<2x4x16x128xf32> +} +// CHECK-LABEL: func.func @transpose_move_down_binary_case0 +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: mhlo.transpose +// CHECK-NEXT: return + +func.func @transpose_move_down_binary_case1(%arg0 : tensor<2x128x4x16xf32>,%arg1 : tensor<2x128x4x16xf32>) -> tensor<2x4x16x128xf32> { + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[0,2,3,1]> : tensor<4xi64>} : 
(tensor<2x128x4x16xf32>) -> tensor<2x4x16x128xf32> + %1 = "mhlo.transpose"(%arg1) {permutation = dense<[0,2,3,1]> : tensor<4xi64>} : (tensor<2x128x4x16xf32>) -> tensor<2x4x16x128xf32> + %2 = mhlo.add %1, %0 : tensor<2x4x16x128xf32> + return %2 : tensor<2x4x16x128xf32> +} +// CHECK-LABEL: func.func @transpose_move_down_binary_case1 +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: mhlo.transpose +// CHECK-NEXT: return diff --git a/compiler/python/ByteIRModules.cpp b/compiler/python/ByteIRModules.cpp index 3ba11567f..076e24c6d 100644 --- a/compiler/python/ByteIRModules.cpp +++ b/compiler/python/ByteIRModules.cpp @@ -15,6 +15,7 @@ // //===----------------------------------------------------------------------===// +#include "bindings/c/Passes.h" #include "byteir-c/Dialects.h" #include "byteir-c/Passes.h" #include "byteir-c/Translation.h" @@ -30,6 +31,7 @@ static MlirStringRef toMlirStringRef(const std::string &s) { PYBIND11_MODULE(_byteir, m) { byteirRegisterAllPasses(); + mlirRegisterAllMhloPasses(); m.doc() = "byteir python extension"; diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 70c1bc1f8..297beaeb3 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -88,6 +88,11 @@ def compile_cuda( _print_verbose(module, "// IR Dump After GPU Opt:") with context: PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(module.operation) + PassManager.parse("builtin.module(inline)").run(module.operation) + if useBarePtrCallConv: + PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre{use-bare-ptr-memref-call-conv=true}))").run(module.operation) + else: + PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(module.operation) PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(module.operation) PassManager.parse("builtin.module(set-arg-space{" + 
entry_func_str + " all-space={}".format(target) + "})").run(module.operation) if verbose: @@ -127,7 +132,7 @@ def compile_cuda_with_ait( name: str = "model", aggressive_mode: bool = False, parallelism: int = 1, - disable_ait_cache: bool = False, + disable_byteir_cache: bool = False, **kwargs, ): target = "cuda" @@ -143,7 +148,7 @@ def compile_cuda_with_ait( processor = IRProcessor(name, "./workspace", compile_parallelism=parallelism, - disable_ait_cache=disable_ait_cache, + disable_byteir_cache=disable_byteir_cache, verbose=verbose) with context: processor.load_from_file(input) @@ -202,6 +207,11 @@ def compile_cuda_with_ait( _print_verbose(processor.module, "// IR Dump After GPU Opt:") with context: PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(processor.module.operation) + PassManager.parse("builtin.module(inline)").run(processor.module.operation) + if useBarePtrCallConv: + PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre{use-bare-ptr-memref-call-conv=true}))").run(processor.module.operation) + else: + PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation) PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(processor.module.operation) PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(processor.module.operation) if verbose: @@ -241,7 +251,7 @@ def compile( target: str = "cuda", verbose: bool = False, parallelism: int = 1, - disable_ait_cache: bool = False, + disable_byteir_cache: bool = False, **kwargs, ): if target == "cuda": @@ -252,7 +262,7 @@ def compile( entry_func, verbose, parallelism=parallelism, - disable_ait_cache=disable_ait_cache) + disable_byteir_cache=disable_byteir_cache) elif target == "cuda_with_ait_aggressive": compile_cuda_with_ait(input, output, @@ -260,6 +270,6 @@ def compile( verbose, 
aggressive_mode=True, parallelism=parallelism, - disable_ait_cache=disable_ait_cache) + disable_byteir_cache=disable_byteir_cache) else: raise NotImplemented("not implemented target: {}".format(target)) diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index a397eff0d..c1388385b 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -19,7 +19,6 @@ def func_hash_str(func, gpu_type): hash_str = gpu_type + "_" ops = func.entry_block.operations - # assert len(ops) == 2 for op in ops: hash_str += f"{op.get_asm(large_elements_limit=None)};" return hash_str @@ -29,7 +28,7 @@ def __init__(self, job_name, workdir, compile_parallelism = MAX_COMPILATION_PARALLELISM, - disable_ait_cache = False, + disable_byteir_cache = False, verbose = False): self.job_name = job_name self.workdir = workdir @@ -37,14 +36,11 @@ def __init__(self, self.ait_reuse_recorder = {} # key: hash str, value: Tuple(dll_name, ait_module_path) self.compile_parallelism = min(compile_parallelism, MAX_COMPILATION_PARALLELISM) self.pool = multiprocessing.Pool(compile_parallelism) - self.ait_cache = AITCache() + self.byteir_cache = AITCache() self.verbose = verbose - # ait_cache is enabled when ait_reuse is enabled - # in other words, once `ait_reuse` is set to False, - # we will orcely compile all ait ops with bo reuse or cache. 
- self.disable_ait_cache = disable_ait_cache - if not disable_ait_cache: - self.ait_cache.load_or_create_cache() + self.disable_byteir_cache = disable_byteir_cache + if not disable_byteir_cache: + self.byteir_cache.load_or_create_cache() def _get_builder(self, module, subgraph_name, backend="ait"): assert module != None @@ -139,11 +135,11 @@ def ait_opt_pass(self, anchor_only=False, dump_ir=False): self.ait_reuse_recorder[hash_str] = (builder.dll_name, builder.ait_module_path) libs_to_add_to_cache[hash_str] = builder.ait_module_path dedup_work_items.append((hash_str, func_ir_str)) - - # search in ait cache + + # search in byteir cache work_items_not_in_cache = [] for hash_str, func_ir_str in dedup_work_items: - cached_lib = self.ait_cache.find(gpu_type, hash_str) + cached_lib = self.byteir_cache.find(gpu_type, hash_str) if cached_lib != None: # hit, copy cached lib context = ir.Context() @@ -171,12 +167,12 @@ def ait_opt_pass(self, anchor_only=False, dump_ir=False): t_ed = time.time() print("compilation finished in {}s".format(t_ed-t_st)) - # update ait cache - if not self.disable_ait_cache: + # update byteir cache + if not self.disable_byteir_cache: for key, lib_path in libs_to_add_to_cache.items(): - self.ait_cache.add(gpu_type, key, lib_path, override=False) - self.ait_cache._save() - self.ait_cache.close_cache() + self.byteir_cache.add(gpu_type, key, lib_path, override=False) + self.byteir_cache._save() + self.byteir_cache.close_cache() with self.module.context: pm = PassManager.parse("builtin.module(func.func(gen-ait-config{{func-names={} ait-lib-paths={}}}))".format(funcNameArg, aitLibPathArg)) diff --git a/compiler/python/byteir/tools/compiler.py b/compiler/python/byteir/tools/compiler.py index 5385614af..51aedc583 100644 --- a/compiler/python/byteir/tools/compiler.py +++ b/compiler/python/byteir/tools/compiler.py @@ -11,7 +11,7 @@ parser.add_argument("--target", type=str, default="cuda", help="target device name") parser.add_argument("-v", "--verbose", 
action="store_true") parser.add_argument("--ait_parallelism", type=int, default=1, help="number of processes to compile ait op") - parser.add_argument("--disable_ait_cache", action="store_true") + parser.add_argument("--disable_byteir_cache", action="store_true") args = parser.parse_args() byteir.compile(args.input_mlir_path, @@ -20,6 +20,6 @@ args.target, args.verbose, args.ait_parallelism, - args.disable_ait_cache) + args.disable_byteir_cache) diff --git a/compiler/test/Conversion/HloToCat/fused_ops.mlir b/compiler/test/Conversion/HloToCat/fused_ops.mlir index 0b84fc4d5..92e0f2d1c 100644 --- a/compiler/test/Conversion/HloToCat/fused_ops.mlir +++ b/compiler/test/Conversion/HloToCat/fused_ops.mlir @@ -61,13 +61,22 @@ func.func @test_bmm_rcr_permute(%arg0: tensor<384x256x256xf32>, %arg1: tensor<38 // CHECK-NEXT: cat.bmm_rcr_permute // CHECK-NEXT: return +func.func @test_not_bmm_rrr_permute(%arg0: tensor<1x64x4096xf32>, %arg1: tensor<1x4096x4096xf32>) -> tensor<1x32x64x128xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x64x4096xf32>, tensor<1x4096x4096xf32>) -> tensor<1x64x4096xf32> + %1 = mhlo.reshape %0 : (tensor<1x64x4096xf32>) -> tensor<1x64x32x128xf32> + %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x64x32x128xf32>) -> tensor<1x32x64x128xf32> + return %2 : tensor<1x32x64x128xf32> +} +// CHECK-LABEL: func.func @test_not_bmm_rrr_permute +// CHECK-NOT: cat.bmm_rrr_permute + func.func @test_bmm_rrr_add_0(%arg0: tensor<384x256x256xf32>, %arg1: tensor<384x256x64xf32>, %arg2: tensor<384x256x64xf32>) -> tensor<384x256x64xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<384x256x256xf32>, tensor<384x256x64xf32>) -> tensor<384x256x64xf32> %1 = mhlo.add %0, %arg2 : (tensor<384x256x64xf32>, tensor<384x256x64xf32>) -> tensor<384x256x64xf32> return %1 : tensor<384x256x64xf32> } -// CHECK: func.func @test_bmm_rrr_add_0 +// CHECK-LABEL: func.func 
@test_bmm_rrr_add_0 // CHECK-NEXT: cat.bmm_rrr_add // CHECK-NEXT: return @@ -77,7 +86,7 @@ func.func @test_bmm_rrr_add_1(%arg0: tensor<384x256x256xf32>, %arg1: tensor<384x return %1 : tensor<384x256x64xf32> } -// CHECK: func.func @test_bmm_rrr_add_1 +// CHECK-LABEL: func.func @test_bmm_rrr_add_1 // CHECK-NEXT: cat.bmm_rrr_add // CHECK-NEXT: return @@ -88,7 +97,7 @@ func.func @test_bmm_crr_add(%arg0: tensor<384x256x256xf32>, %arg1: tensor<384x25 return %2 : tensor<384x256x64xf32> } -// CHECK: func.func @test_bmm_crr_add +// CHECK-LABEL: func.func @test_bmm_crr_add // CHECK-NEXT: cat.bmm_crr_add // CHECK-NEXT: return @@ -205,7 +214,7 @@ func.func @test_gemm_rrr_bias(%arg0: tensor<2x2048xf32>, %arg1: tensor<2048x1001 return %3 : tensor<2x1001xf32> } -// CHECK: func.func @test_gemm_rrr_bias +// CHECK-LABEL: func.func @test_gemm_rrr_bias // CHECK-NEXT: mhlo.constant // CHECK-NEXT: cat.gemm_rrr_bias // CHECK-NEXT: return @@ -216,7 +225,7 @@ func.func @test_bmm_crc(%arg0: tensor<512x1024x128xf16>, %arg1: tensor<512x1024x return %1 : tensor<512x1024x128xf16> } -// CHECK: func.func @test_bmm_crc +// CHECK-LABEL: func.func @test_bmm_crc // CHECK-NEXT: cat.bmm_crc // CHECK-NEXT: return @@ -226,7 +235,7 @@ func.func @test_bmm_rrc(%arg0: tensor<512x128x1024xf16>, %arg1: tensor<512x1024x return %1 : tensor<512x1024x128xf16> } -// CHECK: func.func @test_bmm_rrc +// CHECK-LABEL: func.func @test_bmm_rrc // CHECK-NEXT: cat.bmm_rrc // CHECK-NEXT: return @@ -237,7 +246,7 @@ func.func @test_transpose_reshape_bmm_rrr_to_reshape_bmm_rcr(%arg0: tensor<64x12 return %2 : tensor<64x128x128xf16> } -// CHECK: func.func @test_transpose_reshape_bmm_rrr_to_reshape_bmm_rcr +// CHECK-LABEL: func.func @test_transpose_reshape_bmm_rrr_to_reshape_bmm_rcr // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: cat.bmm_rcr // CHECK-NEXT: return @@ -248,7 +257,7 @@ func.func @test_bmm_rrr_reshape_transpose_to_bmm_rrc_reshape(%arg0: tensor<64x12 %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 1, 3, 2]> : 
tensor<4xi64>} : (tensor<2x32x128x128xf16>) -> tensor<2x32x128x128xf16> return %2 : tensor<2x32x128x128xf16> } -// CHECK: func.func @test_bmm_rrr_reshape_transpose_to_bmm_rrc_reshape +// CHECK-LABEL: func.func @test_bmm_rrr_reshape_transpose_to_bmm_rrc_reshape // CHECK-NEXT: cat.bmm_rrc // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: return @@ -259,7 +268,7 @@ func.func @test_bmm_crr_reshape_transpose_to_bmm_crc_reshape(%arg0: tensor<512x1 %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<16x32x128x128xf16>) -> tensor<16x32x128x128xf16> return %2 : tensor<16x32x128x128xf16> } -// CHECK: func.func @test_bmm_crr_reshape_transpose_to_bmm_crc_reshape +// CHECK-LABEL: func.func @test_bmm_crr_reshape_transpose_to_bmm_crc_reshape // CHECK-NEXT: cat.bmm_crc // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: return @@ -268,8 +277,43 @@ func.func @test_softmax_f16(%arg0 : tensor<1x12x1024x1024xf16>) -> tensor<1x12x1 %0 = mhlo.custom_call @byteir.softmax(%arg0) {backend_config = "", byteir_attrs = {axis = 3 : i64}} : (tensor<1x12x1024x1024xf16>) -> tensor<1x12x1024x1024xf32> return %0 : tensor<1x12x1024x1024xf32> } - -// CHECK: func.func @test_softmax_f16 +// CHECK-LABEL: func.func @test_softmax_f16 // CHECK-NEXT: cat.softmax // CHECK-NEXT: mhlo.convert // CHECK-NEXT: return + +func.func @test_bmm_rrr_broadcast_to_reshape_gemm_rrr_reshape(%arg0: tensor<16x1024x4096xf16>, %arg1: tensor<4096x4096xf16>) -> tensor<16x1024x4096xf16> { + %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4096x4096xf16>) -> tensor<16x4096x4096xf16> + %1 = "cat.bmm_rrr"(%arg0, %0) : (tensor<16x1024x4096xf16>, tensor<16x4096x4096xf16>) -> tensor<16x1024x4096xf16> + return %1 : tensor<16x1024x4096xf16> +} +// CHECK-LABEL: func.func @test_bmm_rrr_broadcast_to_reshape_gemm_rrr_reshape +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: cat.gemm_rrr +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: return + +func.func 
@test_transpose_bmm_rrr_broadcast_to_gemm_rrr_permute(%arg0: tensor<16x1024x4096xf16>, %arg1: tensor<4096x4096xf16>) -> tensor<16x32x1024x128xf16> { + %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4096x4096xf16>) -> tensor<16x4096x4096xf16> + %1 = "cat.bmm_rrr"(%arg0, %0) : (tensor<16x1024x4096xf16>, tensor<16x4096x4096xf16>) -> tensor<16x1024x4096xf16> + %2 = mhlo.reshape %1 : (tensor<16x1024x4096xf16>) -> tensor<16x1024x32x128xf16> + %3 = "mhlo.transpose"(%2) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<16x1024x32x128xf16>) -> tensor<16x32x1024x128xf16> + return %3 : tensor<16x32x1024x128xf16> +} +// CHECK-LABEL: func.func @test_transpose_bmm_rrr_broadcast_to_gemm_rrr_permute +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: cat.gemm_rrr_permute +// CHECK-NEXT: return + +func.func @test_transpose_bmm_rrr_broadcast_to_gemm_rcr_permute(%arg0: tensor<16x1024x4096xf16>, %arg1: tensor<4096x4096xf16>) -> tensor<16x32x1024x128xf16> { + %t = "mhlo.transpose"(%arg1) {permutation = dense<[1,0]> : tensor<2xi64>} : (tensor<4096x4096xf16>) -> tensor<4096x4096xf16> + %0 = "mhlo.broadcast_in_dim"(%t) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4096x4096xf16>) -> tensor<16x4096x4096xf16> + %1 = "cat.bmm_rrr"(%arg0, %0) : (tensor<16x1024x4096xf16>, tensor<16x4096x4096xf16>) -> tensor<16x1024x4096xf16> + %2 = mhlo.reshape %1 : (tensor<16x1024x4096xf16>) -> tensor<16x1024x32x128xf16> + %3 = "mhlo.transpose"(%2) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<16x1024x32x128xf16>) -> tensor<16x32x1024x128xf16> + return %3 : tensor<16x32x1024x128xf16> +} +// CHECK-LABEL: func.func @test_transpose_bmm_rrr_broadcast_to_gemm_rcr_permute +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: cat.gemm_rcr_permute +// CHECK-NEXT: return \ No newline at end of file diff --git a/compiler/test/Conversion/ToByre/convertMemRefToByre.mlir b/compiler/test/Conversion/ToByre/convertMemRefToByre.mlir index 
9eec21e83..4dfa8f99a 100644 --- a/compiler/test/Conversion/ToByre/convertMemRefToByre.mlir +++ b/compiler/test/Conversion/ToByre/convertMemRefToByre.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt -convert-lmhlo-to-byre %s | FileCheck %s +// RUN: byteir-opt -memref-to-byre --split-input-file %s | FileCheck %s module attributes {byre.container_module} { // CHECK: module attributes {byre.container_module} { @@ -8,4 +8,22 @@ module attributes {byre.container_module} { // CHECK: byre.copy(%arg0, %alloc) {callee = "cpu2gpu"} : memref<4xf32, "cpu">, memref<4xf32, "gpu"> return } -} \ No newline at end of file +} + +// ----- + +module attributes {byre.container_module} { +// CHECK: module attributes {byre.container_module} { + func.func @forward(%arg0: memref {byre.argname = "A", byre.argtype = 1 : i32}, %arg1: memref<2xi64, "cuda"> {byre.argname = "Out", byre.argtype = 2 : i32}) attributes { byre.entry_point } { + %expand_shape = memref.expand_shape %arg0 [] : memref into memref<1xi64, "cuda"> + // CHECK: byre.alias + %alloc = memref.alloc() : memref<2xi64, "cuda"> + %subview = memref.subview %alloc[0] [1] [1] : memref<2xi64, "cuda"> to memref<1xi64, strided<[1]>, "cuda"> + // CHECK: byre.alias + memref.copy %expand_shape, %subview : memref<1xi64, "cuda"> to memref<1xi64, strided<[1]>, "cuda"> + // CHECK: byre.copy + memref.copy %alloc, %arg1 : memref<2xi64, "cuda"> to memref<2xi64, "cuda"> + // CHECK: byre.copy + return + } +} diff --git a/compiler/test/Dialect/Linalg/transform-op-fold-unit-extent-dims.mlir b/compiler/test/Dialect/Linalg/transform-op-fold-unit-extent-dims.mlir index 8759d25ef..f8e0c1992 100644 --- a/compiler/test/Dialect/Linalg/transform-op-fold-unit-extent-dims.mlir +++ b/compiler/test/Dialect/Linalg/transform-op-fold-unit-extent-dims.mlir @@ -18,6 +18,8 @@ func.func @tensor_collapse(%arg0 : tensor<12x1024x1024xf32>, %arg1 : tensor<1x10 transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match 
ops{["linalg.generic"]} in %arg0 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.fold_unit_extent_dims %0 + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns to %0 { + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes + } : !pdl.operation } diff --git a/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir b/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir index 1b13f2851..2ed4c0d77 100644 --- a/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir +++ b/compiler/test/Dialect/Mhlo/transforms/ConvertOpToCustomCall.mlir @@ -61,3 +61,22 @@ func.func @convert_rng_dynamic(%arg0: tensor) -> tensor { // CHECK-NEXT: call @NextOffsetFunc // CHECK-NEXT: mhlo.custom_call // CHECK-SAME: @byteir.rng_uniform + +// ----- + +func.func @flash_attn_fwd(%arg0: tensor<2x256x12x128xf16>, %arg1: tensor<2x256x12x128xf16>, %arg2: tensor<2x256x12x128xf16>) -> (tensor<2x256x12x128xf16>, tensor<2x12x256xf32>, tensor<2x12x256x256xf16>, tensor<2xi64>) { + %0:4 = mhlo.custom_call @byteir.flash_attn_fwd(%arg0, %arg1, %arg2) {backend_config = "", byteir_attrs = {casual = false, dropout_p = 1.000000e-01 : f64, return_softmax = false, softmax_scale = 1.000000e+00 : f64}} : (tensor<2x256x12x128xf16>, tensor<2x256x12x128xf16>, tensor<2x256x12x128xf16>) -> (tensor<2x256x12x128xf16>, tensor<2x12x256xf32>, tensor<2x12x256x256xf16>, tensor<2xi64>) + return %0#0, %0#1, %0#2, %0#3 : tensor<2x256x12x128xf16>, tensor<2x12x256xf32>, tensor<2x12x256x256xf16>, tensor<2xi64> +} + +// CHECK-LABEL: func.func private @NextOffsetFunc() -> tensor attributes {byre_compute_name = "NextOffset", byre_force_compute_name} +// CHECK-LABEL: func.func private @GetSeedFunc() -> tensor attributes {byre_compute_name = "GetSeed", byre_force_compute_name} +// CHECK-LABEL: func.func @flash_attn_fwd +// CHECK-NEXT: call @GetSeedFunc +// CHECK-NEXT: call @NextOffsetFunc +// CHECK-NEXT: 
mhlo.reshape +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: mhlo.concatenate +// CHECK-NEXT: mhlo.custom_call +// CHECK-SAME: @byteir.flash_attn_fwd +// CHECK-SAME: byteir_attrs = {casual = false, dropout_p = 1.000000e-01 : f64, return_softmax = false, softmax_scale = 1.000000e+00 : f64} diff --git a/compiler/test/Dialect/Mhlo/transforms/hloMoveDown.mlir b/compiler/test/Dialect/Mhlo/transforms/hloMoveDown.mlir index 3a46a008b..aa83789d5 100644 --- a/compiler/test/Dialect/Mhlo/transforms/hloMoveDown.mlir +++ b/compiler/test/Dialect/Mhlo/transforms/hloMoveDown.mlir @@ -34,6 +34,18 @@ func.func @transpose_move_down_binary_splat_const(%arg0 : tensor<31x20x32xf32>) // CHECK-NEXT: mhlo.transpose // CHECK-NEXT: return +func.func @transpose_move_down_binary_dense_const(%arg0 : tensor<3x2xf32>) -> tensor<2x3xf32> { + %0 = mhlo.constant dense<[[1.0,2.0,3.0],[4.0,5.0,6.0]]> : tensor<2x3xf32> + %1 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x2xf32>) -> tensor<2x3xf32> + %2 = mhlo.add %1, %0 : tensor<2x3xf32> + return %2 : tensor<2x3xf32> +} +// CHECK-LABEL: func.func @transpose_move_down_binary_dense_const +// CHECK-NEXT: mhlo.constant {{.*}} tensor<3x2xf32> +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: mhlo.transpose +// CHECK-NEXT: return + func.func @transpose_move_down_unary_and_cancel(%arg0 : tensor<31x20x32xf32>) -> tensor<31x20x32xf32> { %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<31x20x32xf32>) -> tensor<20x31x32xf32> %1 = "mhlo.abs"(%0) : (tensor<20x31x32xf32>) -> tensor<20x31x32xf32> @@ -76,19 +88,17 @@ func.func @transpose_move_down_two_unary(%arg0 : tensor<31x20x32xf32>) -> tensor // CHECK-NEXT: return // MULTIUSER-LABEL: func.func @transpose_move_down_two_unary -// MULTIUSER-DAG{ABS}: mhlo.abs -// MULTIUSER-NEXT{ABS}: mhlo.transpose -// MULTIUSER-DAG{SINE}: mhlo.sine -// MULTIUSER-NEXT{SINE}: mhlo.transpose +// MULTIUSER-DAG: mhlo.abs +// MULTIUSER-DAG: mhlo.sine // MULTIUSER: mhlo.add +// 
MULTIUSER-NEXT: mhlo.transpose // MULTIUSER-NEXT: return // AllMULTIUSER-LABEL: func.func @transpose_move_down_two_unary -// AllMULTIUSER-DAG{ABS}: mhlo.abs -// AllMULTIUSER-NEXT{ABS}: mhlo.transpose -// AllMULTIUSER-DAG{SINE}: mhlo.sine -// AllMULTIUSER-NEXT{SINE}: mhlo.transpose +// AllMULTIUSER-DAG: mhlo.abs +// AllMULTIUSER-DAG: mhlo.sine // AllMULTIUSER: mhlo.add +// AllMULTIUSER-NEXT: mhlo.transpose // AllMULTIUSER-NEXT: return func.func @transpose_move_down_1_unary_1_invalid(%arg0 : tensor<31x20x32xf32>, %arg1 : tensor<20x31x32xf32>)-> tensor<20x31x32xf32> { diff --git a/compiler/test/Dialect/Tensor/canonicalizeExt.mlir b/compiler/test/Dialect/Tensor/canonicalizeExt.mlir index 83b828455..b0042aa23 100644 --- a/compiler/test/Dialect/Tensor/canonicalizeExt.mlir +++ b/compiler/test/Dialect/Tensor/canonicalizeExt.mlir @@ -34,3 +34,14 @@ func.func @extract_slice_and_collapse_shape_no_fold(%arg0: tensor<19x1024x1xi32> // CHECK: tensor.extract_slice // CHECK: tensor.collapse_shape +// ---- + +func.func @fold_zero_rank_from_elements_insert_slice(%arg0: tensor<1024xf32>, %scalar : f32) -> tensor<1024xf32> { + %0 = tensor.from_elements %scalar : tensor + %1 = tensor.insert_slice %0 into %arg0[256] [1] [1] : tensor into tensor<1024xf32> + return %1 : tensor<1024xf32> +} +// CHECK-LABEL: fold_zero_rank_from_elements_insert_slice +// CHECK: tensor.insert +// CHECK-NOT: tensor.from_elements +// CHECK-NOT: tensor.insert_slice \ No newline at end of file diff --git a/compiler/test/Transforms/canonicalizeExt.mlir b/compiler/test/Transforms/canonicalizeExt.mlir index abb11d500..a7dbb37cb 100644 --- a/compiler/test/Transforms/canonicalizeExt.mlir +++ b/compiler/test/Transforms/canonicalizeExt.mlir @@ -343,3 +343,40 @@ func.func @transpose_reshape_transpose(%arg0: tensor<2x32x128x256xf16>) -> (tens // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: mhlo.reshape // CHECK-NEXT: return + +func.func @replace_gather_with_input_0() -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = 
mhlo.constant dense<1.000000e+00> : tensor<64x128xf16> + %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64xi64> + %2 = "mhlo.gather"(%0, %1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 128]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<64xi64>) -> tensor<64x128xf16> + %3 = mhlo.reshape %2 : (tensor<64x128xf16>) -> tensor<1x64x128xf16> + %4 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %3, %4 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_0 +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: return + +func.func @replace_gather_with_input_1(%arg0: tensor<64x128xf16>) -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64xi64> + %1 = "mhlo.gather"(%arg0, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 128]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<64xi64>) -> tensor<64x128xf16> + %2 = mhlo.reshape %1 : (tensor<64x128xf16>) -> tensor<1x64x128xf16> + %3 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %2, %3 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_1 +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: mhlo.broadcast_in_dim +// CHECK-NEXT: return + +func.func @replace_gather_with_input_2(%arg0: tensor<64x128xf16>) -> (tensor<1x64x128xf16>, tensor<1x32x64x128xf16>) { + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<128xi64> + %1 = "mhlo.gather"(%arg0, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[64, 1]> : tensor<2xi64>} : (tensor<64x128xf16>, tensor<128xi64>) -> tensor<64x128xf16> + %2 = mhlo.reshape %1 : (tensor<64x128xf16>) -> 
tensor<1x64x128xf16> + %3 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<64x128xf16>) -> tensor<1x32x64x128xf16> + return %2, %3 : tensor<1x64x128xf16>, tensor<1x32x64x128xf16> +} +// CHECK-LABEL: @replace_gather_with_input_2 +// CHECK-NEXT: mhlo.reshape +// CHECK-NEXT: mhlo.broadcast_in_dim +// CHECK-NEXT: return diff --git a/compiler/test/Transforms/memoryPlanning.mlir b/compiler/test/Transforms/memoryPlanning.mlir index 1edca2483..bb9abd835 100644 --- a/compiler/test/Transforms/memoryPlanning.mlir +++ b/compiler/test/Transforms/memoryPlanning.mlir @@ -1,7 +1,7 @@ -// RUN: byteir-opt %s -memory-planning --canonicalize --cse | FileCheck %s -// RUN: byteir-opt %s -memory-planning="alignment=64" --canonicalize --cse | byteir-stat --alloc-cnt | FileCheck %s --check-prefix CHECK-STAT -// RUN: byteir-opt %s -memory-planning="alloca" --canonicalize --cse | FileCheck %s --check-prefix CHECK-ALLOCA -// RUN: byteir-opt %s -memory-planning="alloca mem-space=2" --canonicalize --cse | FileCheck %s --check-prefix CHECK-SPACE +// RUN: byteir-opt %s --pass-pipeline='builtin.module(func.func(memory-planning,canonicalize,cse))' | FileCheck %s +// RUN: byteir-opt %s --pass-pipeline='builtin.module(func.func(memory-planning{alignment=64},canonicalize,cse))' | byteir-stat --alloc-cnt | FileCheck %s --check-prefix CHECK-STAT +// RUN: byteir-opt %s --pass-pipeline='builtin.module(func.func(memory-planning{alloca},canonicalize,cse))' | FileCheck %s --check-prefix CHECK-ALLOCA +// RUN: byteir-opt %s --pass-pipeline='builtin.module(func.func(memory-planning{alloca mem-space=2},canonicalize,cse))' | FileCheck %s --check-prefix CHECK-SPACE func.func @test_basic_reuse(%arg0 : memref<256xf32>, %arg1 : memref<256xf32>) -> memref<256xf32> attributes {__placeholder__byre.entry_point} { %0 = memref.alloc() : memref<256xf32> @@ -203,9 +203,9 @@ func.func @test_reuse_sub_chunk_i1(%arg0 : memref<512xi1>, %arg1 : memref<512xi1 func.func 
@test_reuse_single_memory_space(%arg0 : memref<512xf32, 1>, %arg1 : memref<512xf32, 2>) { %0 = memref.alloc() : memref<512xf32, 1> - %1 = memref.alloc() : memref<512xf32, 2> + %1 = memref.alloca() : memref<512xf32, 2> %2 = memref.alloc() : memref<512xf32, 1> - %3 = memref.alloc() : memref<512xf32, 2> + %3 = memref.alloca() : memref<512xf32, 2> "lmhlo.add"(%arg0, %arg0, %0) : (memref<512xf32, 1>, memref<512xf32, 1>, memref<512xf32, 1>) -> () "lmhlo.add"(%arg1, %arg1, %1) : (memref<512xf32, 2>, memref<512xf32, 2>, memref<512xf32, 2>) -> () "lmhlo.add"(%0, %0, %arg0) : (memref<512xf32, 1>, memref<512xf32, 1>, memref<512xf32, 1>) -> () diff --git a/compiler/tools/byteir-opt/CMakeLists.txt b/compiler/tools/byteir-opt/CMakeLists.txt index 9c13a778e..f8b667c21 100644 --- a/compiler/tools/byteir-opt/CMakeLists.txt +++ b/compiler/tools/byteir-opt/CMakeLists.txt @@ -8,6 +8,7 @@ set(BYTEIR_LIBS MLIRCclTransformOps ByteIRAffinePasses ByteIRByrePasses + ByteIRGPUPasses ByteIRGPUPipelines ByteIRHostPipelines ByteIRLinalgPasses diff --git a/compiler/tools/byteir-opt/byteir-opt.cpp b/compiler/tools/byteir-opt/byteir-opt.cpp index 8cb60be72..ed3c15e51 100644 --- a/compiler/tools/byteir-opt/byteir-opt.cpp +++ b/compiler/tools/byteir-opt/byteir-opt.cpp @@ -21,6 +21,7 @@ #include "byteir/Dialect/Ccl/IR/CclOps.h" #include "byteir/Dialect/Ccl/Passes.h" #include "byteir/Dialect/Ccl/TransformOps/CclTransformOps.h" +#include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/Lace/LaceDialect.h" #include "byteir/Dialect/Linalg/IR/LinalgExtOps.h" #include "byteir/Dialect/Linalg/Passes.h" @@ -30,6 +31,7 @@ #include "byteir/Dialect/Shape/IR/ShapeExtOps.h" #include "byteir/Dialect/Shape/Passes.h" #include "byteir/Dialect/Tensor/IR/TilingInterfaceImpl.h" +#include "byteir/Dialect/Tensor/Passes.h" #include "byteir/Dialect/Transform/IR/TransformExtOps.h" #include "byteir/Dialect/Transform/Passes.h" #include "byteir/Dialect/Vector/Transforms/Passes.h" @@ -123,11 +125,13 @@ int main(int argc, 
char **argv) { registerByteIRAffinePasses(); registerByteIRByrePasses(); registerByteIRCclPasses(); + registerByteIRGPUPasses(); registerByteIRLinalgPasses(); registerByteIRMemRefPasses(); registerByteIRMhloPassesExt(); registerByteIRSCFPasses(); registerByteIRShapePasses(); + registerByteIRTensorPasses(); registerByteIRTransformPasses(); registerByteIRVectorPasses(); diff --git a/external/patches/AITemplate/logging.patch b/external/patches/AITemplate/logging.patch new file mode 100644 index 000000000..a5dec1b54 --- /dev/null +++ b/external/patches/AITemplate/logging.patch @@ -0,0 +1,17 @@ +diff --git a/static/csrc/model_container.cpp b/static/csrc/model_container.cpp +index 5548a97..920ed60 100644 +--- a/static/csrc/model_container.cpp ++++ b/static/csrc/model_container.cpp +@@ -80,9 +80,9 @@ ModelContainer::ModelContainer( + useDebugLogging = true; + } + } +- LOG(INFO) +- << (useDebugLogging ? PrintDebugDeviceProperties(prop) +- : PrintInfoDeviceProperties(prop)); ++ //LOG(INFO) ++ // << (useDebugLogging ? 
PrintDebugDeviceProperties(prop) ++ // : PrintInfoDeviceProperties(prop)); + + LOG(INFO) << "Init AITemplate Runtime with " << num_models << " concurrency"; + models_.reserve(num_models); diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp index dfda80cc2..a8918c43c 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp @@ -188,6 +188,7 @@ Value createL2NormWithoutEps(PatternRewriter &rewriter, Location loc, mhlo::CustomCallSchedule::NONE, nullptr, nullptr, rewriter.getArrayAttr(llvm::ArrayRef{})); DictionaryAttrWrapper attrs(rewriter.getContext()); + attrs.setAttr("epsilon", rewriter.getF64FloatAttr(0.0)); attrs.setAttr("axis", rewriter.getI64ArrayAttr({axis})); customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs)); diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir index d0d986dc9..0e87cd743 100644 --- a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir +++ b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir @@ -175,7 +175,7 @@ func.func @test_l2_norm_pat2(%1146: tensor<12x128xf32>) -> tensor<12x128xf32> { return %1148 : tensor<12x128xf32> // CHECK-LABEL: @test_l2_norm_pat2 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x128xf32>) -> tensor<12x128xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [1]}} : (tensor<12x128xf32>) -> tensor<12x128xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64}} : (tensor<12x128xf32>) -> tensor<12x128xf32> // CHECK-NEXT: return [[VAR_0_]] : tensor<12x128xf32> 
} diff --git a/frontends/torch-frontend/examples/demo/README.md b/frontends/torch-frontend/examples/demo/README.md new file mode 100644 index 000000000..65ce80a6a --- /dev/null +++ b/frontends/torch-frontend/examples/demo/README.md @@ -0,0 +1,16 @@ +# ByteIR GPU Compiler for LLM on Torch 2.0 + +### Steps to run +1. Build docker image with [Dockerfile](../../../../docker/Dockerfile). +2. Download ByteIR release and unzip it. +3. Install ByteIR components: + * python3 -m pip install -r ByteIR/requirements.txt + * python3 -m pip install ByteIR/*.whl +4. Run training demo: + * python3 main.py \<model-name> <--flash> + * **model-name:** ["gpt2", "bloom-560m", "llama", "opt-1.3b", "nanogpt"] + * **--flash:** enables flash attention +5. Run inference demo: + * python3 main.py \<model-name> --infer <--flash> + * **model-name:** ["llama"] + * **--flash:** enables flash attention diff --git a/frontends/torch-frontend/examples/demo/backend.py b/frontends/torch-frontend/examples/demo/backend.py new file mode 100644 index 000000000..2e67b84d4 --- /dev/null +++ b/frontends/torch-frontend/examples/demo/backend.py @@ -0,0 +1,195 @@ +import os +import torch +import functools + +from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode +from torch.fx.passes.fake_tensor_prop import FakeTensorProp + +import brt +import byteir + +from torch_frontend import compile +from torch_frontend import list_decomposed_ops, preprocess_fx_graph, fx_replace_attn_pattern, replace_flash_attn, get_none_indices + + +TRACE = False + +cnt = 0 +MODEL_NAME = '' +FLASH = False + + +from functorch.compile import aot_module +from torch._decomp import get_decompositions + +from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete + +class ByteIRInferenceFunction: + def __init__(self, module_path): + self._session = brt.Session(alloc_func=caching_allocator_alloc, + free_func=caching_allocator_delete) + self._session.load(module_path) + self._req = self._session.new_request_context( + 
torch.cuda.current_stream()._as_parameter_.value) + + def __call__(self, *inputs): + device = inputs[0].device + from brt.utils import brt_dtype_to_torch_dtype + results = [torch.empty(self._session.get_static_shape(offset), + dtype=brt_dtype_to_torch_dtype(self._session.get_data_type(offset)), + device=device) for offset in self._session.get_output_arg_offsets()] + + for offset, input in zip(self._session.get_input_arg_offsets(), inputs): + self._req.bind_arg(offset, input.data_ptr()) + for offset, output in zip(self._session.get_output_arg_offsets(), results): + self._req.bind_arg(offset, output.data_ptr()) + self._req.finish_io_binding() + self._req.run() + self._req.sync() + return results + +class ByteIRFunction: + def __init__(self, module_path, output_shapes, output_dtypes, none_indices): + self._session = brt.Session( + alloc_func=caching_allocator_alloc, + free_func=caching_allocator_delete) + self._session.load(module_path) + self._output_shapes = output_shapes + self._output_dtypes = output_dtypes + self._req = self._session.new_request_context( + torch.cuda.current_stream()._as_parameter_.value) + self._none_indices = none_indices + + def __call__(self, *inputs): + if TRACE: + for i in range(len(inputs)): + input = inputs[i] + print("In ByteIRFunction, Inputs["+str(i)+"]", input) + + device = inputs[0].device + rets = [torch.empty(shape, dtype=dtype, device=device) + for shape, dtype in zip(self._output_shapes, self._output_dtypes)] + for offset, arg in zip(self._session.get_input_arg_offsets(), inputs): + assert list(self._session.get_static_shape(offset)) == list(arg.shape) + for offset, ret in zip(self._session.get_output_arg_offsets(), rets): + assert list(self._session.get_static_shape(offset)) == list(ret.shape) + + for i, tensor in zip(self._session.get_input_arg_offsets(), inputs): + self._req.bind_arg(i, tensor.data_ptr()) + for i, tensor in zip(self._session.get_output_arg_offsets(), rets): + self._req.bind_arg(i, tensor.data_ptr()) + 
self._req.finish_io_binding() + self._req.run() + self._req.sync() + + if TRACE: + for i in range(len(rets)): + r = rets[i] + print("In ByteIRFunction, Outputs["+str(i)+"]", r) + + # add None results to return values + results = [] + none_ptr = 0 + ret_ptr = 0 + for i in range(len(rets) + len(self._none_indices)): + if none_ptr < len(self._none_indices) and i == self._none_indices[none_ptr]: + results.append(None) + none_ptr += 1 + else: + results.append(rets[ret_ptr]) + ret_ptr += 1 + return results + +def byteir_compile_fx_inner(graph: torch.fx.GraphModule, inputs, is_backward, ban_lst=[]): + category = 'backward' if is_backward else 'forward' + + print("\n\n============") + print(f"{category} Part") + print("============\n\n") + none_indices = get_none_indices(graph) + fx_graph = preprocess_fx_graph(graph) + + compile_type = 'mhlo' + backend_legal_ops = [ + "aten._softmax", + "aten.softmax.int", + "aten.log_softmax.int", + "aten._log_softmax", + # "aten.native_layer_norm", + # "aten.layer_norm", + "aten.gelu", + "aten.argmax", + "aten.max.dim", + "aten.one_hot", + "aten.topk", + "byteir.flash_attn_fwd", + "byteir.flash_attn_bwd", + ] + with maybe_disable_fake_tensor_mode(): + compiled_graph = compile(fx_graph, inputs, compile_type, backend_legal_ops=backend_legal_ops) + + model_name = MODEL_NAME + global cnt + TEMP_FOLDER="./temp" + os.makedirs(TEMP_FOLDER, exist_ok=True) + os.makedirs(TEMP_FOLDER + f"/{model_name}_{category}", exist_ok=True) + mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}_{cnt}.{compile_type}.mlir' + output_mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}/{model_name}_{category}.rt.mlir' + cnt = cnt + 1 + with open(mlir_file_name, "w+") as fout: + compiled_graph.operation.print(file=fout, + large_elements_limit=None) + + with maybe_disable_fake_tensor_mode(): + byteir.compile(mlir_file_name, output_mlir_file_name, entry_func='forward', target='cuda_with_ait') + + outputs = FakeTensorProp(graph).propagate(*inputs) + 
mhlo_ret_dtypes = [t.dtype for t in outputs] + mhlo_ret_shapes = [t.shape for t in outputs] + + print(output_mlir_file_name) + return ByteIRFunction(output_mlir_file_name, mhlo_ret_shapes, mhlo_ret_dtypes, none_indices) + + +from torch._inductor.virtualized import V +from torch._dynamo.utils import detect_fake_mode +from torch._dynamo.backends.common import aot_autograd +from torch._inductor.fx_passes.joint_graph import joint_graph_passes + + +def fuse_aware_byteir_compile_fx(model_: torch.fx.GraphModule, example_inputs_): + from partitioners import fuse_aware_min_cut_partition + # TODO: can add logging before/after the call to create_aot_dispatcher_function + # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func + # once torchdynamo is merged into pytorch + fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode( + allow_non_fake_inputs=True + ) + tracing_context = ( + torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode) + ) + decompose_list = list_decomposed_ops() + decompositions = get_decompositions(decompose_list) + + def partition_fn(graph, joint_inputs, **kwargs): + joint_graph_passes(graph) + return fuse_aware_min_cut_partition( + graph, joint_inputs, **kwargs, compiler="inductor" + ) + + if FLASH: + # preprocess flash attention + # replace attention pattern to scaled_dot_product_attention + model_ = fx_replace_attn_pattern(model_) + # replace scaled_dot_product_attention to byteir.flash_attn + model_ = replace_flash_attn(model_) + + with V.set_fake_mode(fake_mode), torch._guards.tracing(tracing_context): + return aot_autograd( + fw_compiler=functools.partial(byteir_compile_fx_inner, is_backward=False), + bw_compiler=functools.partial(byteir_compile_fx_inner, is_backward=True), + inference_compiler=functools.partial(byteir_compile_fx_inner, is_backward=False), + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + 
)(model_, example_inputs_) diff --git a/frontends/torch-frontend/examples/demo/byteir_fusible_pattern.py b/frontends/torch-frontend/examples/demo/byteir_fusible_pattern.py new file mode 100644 index 000000000..1629f39f5 --- /dev/null +++ b/frontends/torch-frontend/examples/demo/byteir_fusible_pattern.py @@ -0,0 +1,194 @@ +import torch +import torch.fx as fx + +from compile_utils import get_aten_target +from fx_match_utils import get_node_consumer, match_chain + +byteir_fusible_patterns = {} +aten = torch.ops.aten + +def register_byteir_pattern(name): + def register(pattern): + if name in byteir_fusible_patterns.keys(): + raise ValueError("Pattern " + name + " has already been registerd.") + byteir_fusible_patterns[name] = pattern + return pattern + return register + + +class ByteIRFusiblePattern: + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + raise NotImplementedError + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + raise NotImplementedError + +@register_byteir_pattern("transpose_dot") +class TransposeDotPattern(ByteIRFusiblePattern): + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + post_fusible_ops = [aten.mm, aten.bmm] + if get_aten_target(node) in [aten.t, aten.transpose]: + can_fuse = all(get_aten_target(user) in post_fusible_ops for user in node.users) + all_fw_node = all(user in required_fw_nodes for user in node.users) + return (not all_fw_node) and can_fuse + return False + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + if cls.match(node, required_fw_nodes): + return [node] + return [] + + +@register_byteir_pattern("transpose_reshape_transpose_dot") +class TransposeReshapeTransposeDotPattern(ByteIRFusiblePattern): + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + post_fusible_ops = [aten.mm, aten.bmm, aten.transpose] + if get_aten_target(node) not in [aten.transpose]: + return False + if match_chain(node, 
target_chain=(aten.transpose, aten.expand, aten.clone, aten._unsafe_view)): + expand_node = get_node_consumer(node, 0) + clone_node = get_node_consumer(expand_node, 0) + view_node = get_node_consumer(clone_node, 0) + all_fw_node = all(user in required_fw_nodes for user in view_node.users) + can_fuse = all(get_aten_target(user) in post_fusible_ops for user in view_node.users) + return (not all_fw_node) and can_fuse + return False + + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + if cls.match(node, required_fw_nodes): + expand_node = get_node_consumer(node, 0) + clone_node = get_node_consumer(expand_node, 0) + view_node = get_node_consumer(clone_node, 0) + recompute_nodes = [node, expand_node, clone_node, view_node] + for user in view_node.users: + if user not in required_fw_nodes: + recompute_nodes.append(user) + return recompute_nodes + return [] + +@register_byteir_pattern("transpose_transpose") +class TransposeTransposePattern(ByteIRFusiblePattern): + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + if get_aten_target(node) in [aten.t, aten.transpose]: + for user in node.users: + if get_aten_target(user) in [aten.t, aten.transpose]: + all_fw_node = all(n in required_fw_nodes for n in user.users) + if not all_fw_node: + return True + return False + + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + if cls.match(node, required_fw_nodes): + recompute_nodes = [node] + for user in node.users: + if get_aten_target(user) == aten.t: + recompute_nodes.append(user) + return recompute_nodes + return [] + + +@register_byteir_pattern("full_bitwise_not_expand") +class FullBitwiseNotExpandPattern(ByteIRFusiblePattern): + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + if match_chain(node, target_chain=(aten.full, aten.bitwise_not, aten.expand)): + return True + return False + + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + if 
cls.match(node, required_fw_nodes): + bitwise_node = get_node_consumer(node, 0) + expand_node = get_node_consumer(bitwise_node, 0) + recompute_nodes = [node, bitwise_node, expand_node] + return recompute_nodes + return [] + + +# Note: This pattern is temporary. +# It is only used to fix issue that full op(dtype is bool) is not supported in byteir. +@register_byteir_pattern("copy_bitwise_not_expand") +class CopyBitwiseNotExpandPattern(ByteIRFusiblePattern): + + @classmethod + def match(cls, node, required_fw_nodes) -> bool: + if match_chain(node, target_chain=(aten._to_copy, aten.bitwise_not, aten.expand, aten.bitwise_or)): + bitwise_not_node = get_node_consumer(node, 0) + expand_node = get_node_consumer(bitwise_not_node, 0) + bitwise_or_node = get_node_consumer(expand_node, 0) + return True + return False + + + @classmethod + def get_pattern_recompute_nodes(cls, node, required_fw_nodes): + if cls.match(node, required_fw_nodes): + bitwise_not = get_node_consumer(node, 0) + expand = get_node_consumer(bitwise_not, 0) + bitwise_or = get_node_consumer(expand, 0) + recompute_nodes = [node, bitwise_not, expand, bitwise_or] + return recompute_nodes + return [] + + +def greedy_transpose_fusion(joint_graph, required_fw_nodes): + recompute_nodes = [] + post_fuse_ops = [aten.bmm, aten.mm] + transparent_ops = [aten.clone, aten._to_copy, aten.expand] + view_ops = [aten.view, aten._unsafe_view] + transpose_ops = [aten.t, aten.transpose] + fusible_tag = {} + + INIT_TAG = 0 + POST_FUSION_TAG = 1 + TRANSPOSE_TAG = 2 + + + for node in reversed(joint_graph.nodes): + fusible_tag[node] = INIT_TAG + + for node in reversed(joint_graph.nodes): + if get_aten_target(node) in post_fuse_ops and node not in required_fw_nodes: + fusible_tag[node] = POST_FUSION_TAG + + if get_aten_target(node) in transparent_ops: + for user in node.users: + if user in fusible_tag.keys() and fusible_tag[user] >= POST_FUSION_TAG: + fusible_tag[node] = POST_FUSION_TAG + recompute_nodes.append(node) + + if 
get_aten_target(node) in transpose_ops: + for user in node.users: + if user in fusible_tag.keys() and fusible_tag[user] >= POST_FUSION_TAG: + recompute_nodes.append(node) + fusible_tag[node] = INIT_TAG + + return recompute_nodes + + +def get_byteir_recompute_nodes(joint_graph, required_fw_nodes): + recompute_nodes = [] + recompute_nodes.extend(greedy_transpose_fusion(joint_graph, required_fw_nodes)) + for name, pattern in byteir_fusible_patterns.items(): + for node in joint_graph.nodes: + if node.op == 'output': + continue + recompute_nodes.extend(pattern.get_pattern_recompute_nodes(node, required_fw_nodes)) + recompute_nodes = list(set(recompute_nodes)) + return recompute_nodes diff --git a/frontends/torch-frontend/examples/demo/compile_utils.py b/frontends/torch-frontend/examples/demo/compile_utils.py new file mode 100644 index 000000000..e08df059e --- /dev/null +++ b/frontends/torch-frontend/examples/demo/compile_utils.py @@ -0,0 +1,92 @@ + +import torch +import torch.fx as fx +from torch.utils._pytree import tree_flatten + +aten = torch.ops.aten + + +def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + +rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma, + aten.bernoulli, aten.multinomial, aten.native_dropout, + aten.normal, aten.poisson, aten.binomial, aten.rrelu, + aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm] + + +# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph +def fx_graph_cse(fx_g: torch.fx.graph.Graph): + new_graph = fx.Graph() + env = {} # map from node in the old graph to node in the new graph + hash_env = {} # map from hash to a node in the new graph + token_map = {} # map from hash to token + for n in fx_g.nodes: + # The placeholder, output, and get_attr nodes are copied to the new grpah without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 
'get_attr' or get_aten_target(n) in rand_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, torch.fx.node.Node) and v in env: + arg_list[i] = env[v] + if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)): + arg_list[i] = v.node + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + return new_graph + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. 
+ + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + +def get_placeholders(graph): + return list(filter(lambda x: x.op == 'placeholder', graph.nodes)) + +def get_outputs(graph): + for node in graph.nodes: + if node.op == 'output': + return tree_flatten(node.args[0])[0] + raise AssertionError("No output node found") diff --git a/frontends/torch-frontend/examples/demo/config.py b/frontends/torch-frontend/examples/demo/config.py new file mode 100644 index 000000000..811bce8cc --- /dev/null +++ b/frontends/torch-frontend/examples/demo/config.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Global flags for aot autograd +""" +import os +import sys + +# Converts torch rng ops to their functional philox rng equivalents. Note that +# we functionalize only CUDA rng ops today. +functionalize_rng_ops = False + +# can be useful for debugging if we are incorrectly creating meta fake tensors +fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True) + +# Enables optional asserts in hotpath code to check for errors. If +# you are seeing weird accuracy problems, try turning this on. +# This is currently off by default as it will harm tracing time, +# but it is on by default for aot_eager. +debug_assert = False + +debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False) + +static_weight_shapes = True + +# Applies CSE to the graph before partitioning +cse = True + +# Restricts the amount of computation AOTAutograd can do. 
+max_dist_from_bw = 3 + diff --git a/frontends/torch-frontend/examples/demo/fx_match_utils.py b/frontends/torch-frontend/examples/demo/fx_match_utils.py new file mode 100644 index 000000000..d7a209c76 --- /dev/null +++ b/frontends/torch-frontend/examples/demo/fx_match_utils.py @@ -0,0 +1,40 @@ +import torch +from compile_utils import get_aten_target + +aten = torch.ops.aten + +def is_used_by_specific_consumer(node, consumer_type=None): + if consumer_type == None: + return True + + all_users = list(node.users) + if len(all_users) != 1: + return False + consumer = all_users[0] + if not isinstance(consumer_type, (list, tuple)): + consumer_type = [consumer_type] + if get_aten_target(consumer) not in consumer_type: + return False + return True + + +def get_node_consumer(node, index): + all_users = list(node.users) + return all_users[index] + + +def match_chain(node, target_chain): + if len(target_chain) == 1: + return get_aten_target(node) in target_chain + + if len(list(node.users)) != 1: + return False + + specific_types = target_chain[0] + + if not isinstance(specific_types, (list, tuple)): + specific_types = [specific_types] + + if get_aten_target(node) in specific_types: + return match_chain(get_node_consumer(node, 0), target_chain[1:]) + return False diff --git a/frontends/torch-frontend/examples/demo/main.py b/frontends/torch-frontend/examples/demo/main.py new file mode 100644 index 000000000..c913ceb5c --- /dev/null +++ b/frontends/torch-frontend/examples/demo/main.py @@ -0,0 +1,220 @@ +from torch import nn +import torch +import transformers +import sys +import os +import functools +import torch._dynamo +import torch.nn.functional as F + +import transformers +import argparse + +MODEL_LIST = ["gpt2", "bloom-560m", "llama", "opt-1.3b", "nanogpt"] + +class InferLLAMAModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = transformers.LlamaConfig(num_hidden_layers=4, return_dict=False) + self.model = 
transformers.LlamaForCausalLM(config=self.config) + def forward(self, x): + return self.model(x)[0] + +class InferOPTModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = transformers.AutoConfig.from_pretrained("facebook/opt-1.3b", return_dict=False) + self.config.tie_word_embeddings = False + self.model = transformers.OPTForCausalLM(config=self.config) + def forward(self, x): + return self.model(x)[0] + +class InferBLOOMModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = transformers.BloomConfig.from_pretrained('bigscience/bloom-560m', return_dict=False) + self.config.tie_word_embeddings = False + self.model = transformers.BloomForCausalLM(config=self.config) + def forward(self, x): + return self.model(x)[0] + +class InferGPT2Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = transformers.GPT2Config.from_pretrained('gpt2', return_dict=False) + self.config.num_labels = self.config.vocab_size + self.model = transformers.GPT2ForTokenClassification(config=self.config) + def forward(self, x): + return self.model(x)[0] + +def make_model(model_name): + if model_name == 'llama': + config = transformers.LlamaConfig(num_hidden_layers=4) + model = transformers.LlamaForCausalLM(config=config) + elif model_name == 'opt-1.3b': + config = transformers.AutoConfig.from_pretrained("facebook/opt-1.3b") + config.tie_word_embeddings = False + model = transformers.OPTForCausalLM(config=config) + elif model_name == 'bloom-560m': + config = transformers.BloomConfig.from_pretrained('bigscience/bloom-560m') + config.tie_word_embeddings = False + model = transformers.BloomForCausalLM(config=config) + elif model_name == 'gpt2': + config = transformers.GPT2Config.from_pretrained('gpt2') + config.num_labels = config.vocab_size + model = transformers.GPT2ForTokenClassification(config=config) + elif model_name == 'nanogpt': + from my_transformers.modeling_nanogpt import GPTConfig, GPT + config_args = 
dict(n_layer=12, n_head=12, n_embd=768) + config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints + config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints + config_args['bias'] = True # always True for GPT model checkpoints + # we can override the dropout rate, if desired + config_args['dropout'] = 0. + # create a from-scratch initialized minGPT model + config = GPTConfig(**config_args) + model = GPT(config) + else: + assert False + return model + +def make_inference_model(model_name): + if model_name == 'llama': + return InferLLAMAModule() + elif model_name == 'opt-1.3b': + return InferOPTModule() + elif model_name == 'bloom-560m': + return InferBLOOMModule() + elif model_name == 'gpt2': + return InferGPT2Module() + else: + return make_model(model_name) + +def make_data(model, model_name, device): + batch_size = 8 + if model_name == 'llama': + batch_size = 16 + elif model_name == 'opt-1.3b': + batch_size = 4 + seq_len = 1024 + input = torch.randint( + low=0, high=model.config.vocab_size, size=(batch_size, seq_len), device=device + ) + + label = torch.randint(low=0, high=model.config.vocab_size, size=(batch_size, seq_len), + device=device) + return input, label + +def compute_loss(model, data, model_name): + if model_name == 'nanogpt': + input_idx, output_idx = data + _, loss = model(input_idx, output_idx) + else: + input, label = data + output = model(input) + logits = output.logits + loss = F.cross_entropy(logits.view(-1, model.config.vocab_size), label.view(-1)) + return loss + + +def infer_model(args): + device = torch.device('cuda:' + str(args.device_id)) + model = make_inference_model(args.model_name) + model.eval() + model.to(device) + + input, label = make_data(model, args.model_name, device) + trace_data = [input] + if args.model_name == "nanogpt": + trace_data.append(label) + # torch.save(trace_data, "batch_sample_inputs") + + TEMP_FOLDER="./temp" + os.makedirs(TEMP_FOLDER, exist_ok=True) + os.makedirs(TEMP_FOLDER + 
f"/{args.model_name}_inference_f16", exist_ok=True) + jit_file_name = TEMP_FOLDER + f"/{args.model_name}_inference.f16.jit" + mhlo_file_name = TEMP_FOLDER + f"/{args.model_name}_inference.f16.mhlo.mlir" + byre_file_name = TEMP_FOLDER + f"/{args.model_name}_inference_f16/{args.model_name}.rt.mlir" + + with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): + # if not os.path.exists(jit_file_name): + # module = torch.jit.trace(model, trace_data, check_trace=False) + # torch.jit.save(module, jit_file_name) + # print("save jit to {}".format(jit_file_name)) + + if not os.path.exists(mhlo_file_name): + # module = torch.jit.load(jit_file_name) + if args.flash: + from torch.fx.experimental.proxy_tensor import make_fx + from torch_frontend import preprocess_fx_graph + module = make_fx(model)(*trace_data) + print("torch inputs:") + print(trace_data) + print("torch outputs:") + print(module(*trace_data)) + module = preprocess_fx_graph(module) + else: + module = torch.jit.trace(model, trace_data, check_trace=False) + print("torch inputs:") + print(trace_data) + print("torch outputs:") + print(module(*trace_data)) + import torch_frontend + mhlo_model = torch_frontend.compile(module, trace_data, "mhlo") + with open(mhlo_file_name, "w") as f: + print(mhlo_model.operation.get_asm(), file=f) + print("save mhlo to {}".format(mhlo_file_name)) + + if not os.path.exists(byre_file_name): + import byteir + print("begin byteir compile") + byteir.compile(mhlo_file_name, byre_file_name, entry_func='forward', target='cuda_with_ait', disable_byteir_cache=False, verbose=False) + print("byteir compile to {}".format(byre_file_name)) + + from backend import ByteIRInferenceFunction + runner = ByteIRInferenceFunction(byre_file_name) + print("byre inputs:") + print(trace_data) + print("byre outputs:") + print(runner(*trace_data)) + +def train_model(args): + torch._dynamo.reset() + torch._dynamo.disallow_in_graph(F.cross_entropy) + + model_name = args.model_name + 
use_flash_attn = args.flash + device = torch.device('cuda:' + str(args.device_id)) + model = make_model(model_name) + model.to(device) + + import backend + from backend import fuse_aware_byteir_compile_fx + backend.MODEL_NAME = model_name + backend.FLASH = use_flash_attn + + optimized_model = torch.compile(model, backend=fuse_aware_byteir_compile_fx) + + data = make_data(optimized_model, model_name, device) + model.zero_grad(set_to_none=True) + with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): + loss = compute_loss(optimized_model, data, model_name) + print("loss:", loss) + loss.backward() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name") + parser.add_argument("--flash", action="store_true", help="use flash attention when possible") + parser.add_argument("--infer", action="store_true", help="infer mode") + parser.add_argument("--device_id", type=int, default=0) + args = parser.parse_args() + # print(args) + + assert args.model_name in MODEL_LIST + if args.infer: + infer_model(args) + else: + train_model(args) + diff --git a/frontends/torch-frontend/examples/demo/partitioners.py b/frontends/torch-frontend/examples/demo/partitioners.py new file mode 100644 index 000000000..43bea47b9 --- /dev/null +++ b/frontends/torch-frontend/examples/demo/partitioners.py @@ -0,0 +1,940 @@ +from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.symbolic_shapes import ( + hint_int, magic_methods, method_to_operator, free_symbols, + is_symbol_binding_fx_node, find_symbol_binding_fx_nodes +) +import torch +import torch.fx as fx +import operator +import math +import torch.utils._pytree as pytree +import copy +import os +import itertools +import sympy +from collections import defaultdict +from torch.fx.passes import graph_drawer +from typing import Tuple +from compile_utils import fx_graph_cse, get_aten_target +import config +import functools + +from byteir_fusible_pattern 
def has_recomputable_ops(fx_g):
    """Return True if any node of ``fx_g.graph`` is tagged for recomputation.

    A node counts as tagged when ``must_recompute`` reports a truthy
    ``"recompute"`` entry in its ``meta`` dict.

    Args:
        fx_g: An object exposing ``graph.nodes`` (e.g. an ``fx.GraphModule``).

    Returns:
        bool: True as soon as one tagged node is found, False otherwise.
    """
    return any(must_recompute(node) for node in fx_g.graph.nodes)
+ """ + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in inputs: + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + for node in joint_graph.nodes: + if node in inputs: + continue + elif node.op == 'placeholder': + env[node] = InvalidNode + elif node.op == 'call_function': + all_args = pytree.tree_flatten((node.args, node.kwargs))[0] + all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)] + if any(all_args): + env[node] = InvalidNode + continue + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'get_attr': + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'output': + pass + output_values = [] + for x in outputs: + if isinstance(x, fx.Node): + if x not in env: + raise RuntimeError(f"Node {x} couldn't be found in env") + assert not isinstance(env[x], InvalidNodeBase), f"Node {x} was invalid, but is output" + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(output_values) + + new_graph.eliminate_dead_code() + new_graph.lint() + return new_graph + + +def _is_primal(node): + return ( + node.op == "placeholder" + and "tangents" not in node.target + and not _is_bwd_seed_offset(node) + and not _is_fwd_seed_offset(node) + ) + +def _is_tangent(node): + return node.op == "placeholder" and "tangents" in node.target + +def _is_bwd_seed_offset(node): + return node.op == "placeholder" and ("bwd_seed" in node.target or "bwd_base_offset" in node.target) + +def _is_fwd_seed_offset(node): + return node.op == "placeholder" and ("fwd_seed" in node.target or "fwd_base_offset" in node.target) + + +def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): + outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if 
node.op == 'output'])[0] + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs + + +def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs): + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + bwd_seed_offset_inputs = list(filter(_is_bwd_seed_offset, joint_module.graph.nodes)) + + # Construct the forward module + # Keep symints separate from tensors, passed between fwd/bwd graphs, and in the right order. + fwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + primal_inputs + fwd_seed_offset_inputs, + fwd_outputs + saved_values + saved_sym_nodes + ) + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, + bwd_outputs + ) + + # This is to filter out saved values that don't actually end up being used by the backwards pass + for node in bwd_graph.nodes: + if node.op == 'placeholder' and not node.users: + for saved_value in saved_values: + if saved_value.name == node.name: + saved_values.remove(saved_value) + break + + for saved_sym in saved_sym_nodes: + if saved_sym.name == node.name: + saved_sym_nodes.remove(saved_sym) + break + + # Now that we have the finalized list of saved values, we need to ensure + # we propagate all symbols which are referenced by backwards inputs. 
+ # These are not directly used in the graph but are required for downstream + # sizevar assignment + saved_symbols: Set[sympy.Symbol] = set() + saved_sym_nodes_binding = [] + saved_sym_nodes_derived = [] + + # Some symbols may already be bound in the directly saved_sym_nodes, + # keep track of them so we don't re-bind them + for node in saved_sym_nodes: + symbol = is_symbol_binding_fx_node(node) + if symbol: + saved_symbols.add(symbol) + saved_sym_nodes_binding.append(node) + else: + saved_sym_nodes_derived.append(node) + + # Now go through all of the prospective backward inputs and track any + # other symbols we need to bind + symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) + for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): + if "val" not in node.meta: + continue + new_symbols = free_symbols(node.meta["val"]) - saved_symbols + # NB: Deterministic order please! + for s in sorted(new_symbols, key=lambda s: s.name): + # NB: For well formed graphs, the symbol should always be present, + # but we also have ways to produce ill-formed graphs, e.g., direct + # make_fx usages, so don't choke in this case + if s not in symbol_bindings: + continue + saved_sym_nodes_binding.append(symbol_bindings[s]) + saved_symbols |= new_symbols + + + # Update saved_sym_nodes that are now reordered to have all bindings at + # front. This can also be used later on to figure out the position of saved + # sym nodes in the output of fwd graph. + saved_sym_nodes.clear() + saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) + + # Now, we re-generate the fwd/bwd graphs. 
def default_partition(
    joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
    behavior observed in the original ``.forward()`` and ``.backward()`` of the
    callable, i.e., the resulting forward graph contains those operators that
    are executed in the original ``.forward()`` callable passed to
    :func:`aot_function`.

    The default partitioner collects the operators that are between the forward
    inputs and the forward outputs. This helps in finding the tensors which have
    to be stashed for the backward pass. These stashed tensors become the output
    of the generated forward graph. The remaining operators are then placed in
    the backward graph.

    If any node of the joint graph is tagged for recomputation, the work is
    delegated to :func:`fuse_aware_min_cut_partition` instead.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: The inputs to the joint graph. Only forwarded to the
            min-cut partitioner; otherwise unused here.
        num_fwd_outputs: The number of leading joint-graph outputs that belong
            to the forward graph.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    if has_recomputable_ops(joint_module):
        # The min-cut partitioner in this file is `fuse_aware_min_cut_partition`;
        # the previously referenced `min_cut_rematerialization_partition` is not
        # defined or imported here and raised NameError on this path.
        # NOTE(review): fuse_aware_min_cut_partition falls back to this function
        # when it finds no required backward nodes - confirm that combination
        # (recompute tags but no backward nodes) cannot occur in practice.
        return fuse_aware_min_cut_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)

    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs)
    forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'}
    saved_values = []
    saved_sym_nodes = []

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
        elif (
            'tensor_meta' not in node.meta
            and node.op == 'call_function'
        ):
            # Since we can't save tuple of tensor values, we need to flatten out what we're saving
            users = node.users
            assert all(user.target == operator.getitem for user in users)
            saved_values.extend(users)
        else:
            backward_usages = [n for n in node.users if n.name not in forward_node_names]
            if 'tensor_meta' in node.meta and all(is_sym_node(n) for n in backward_usages):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                saved_sym_nodes.extend(backward_usages)
            else:
                saved_values.append(node)

    # De-duplicate while keeping first-seen order.
    saved_values = list(dict.fromkeys(saved_values))
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes))

    return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
def get_depth(node, depth_map):
    """Compute and memoize the depth of ``node`` in ``depth_map``.

    Placeholders have depth 0. Any other non-output node has depth
    ``1 + max(depth of its input nodes)``, or 1 when it has no input nodes.
    For an ``output`` node the depths of its fx-node arguments are populated
    in ``depth_map`` and ``None`` is returned; the output node itself is not
    recorded.
    """
    try:
        return depth_map[node]
    except KeyError:
        pass

    kind = node.op

    # The output node only triggers the traversal of what it returns.
    if kind == "output":
        for produced in node.args[0]:
            if isinstance(produced, torch.fx.node.Node):
                get_depth(produced, depth_map)
        return None

    if kind == "placeholder":
        # Graph inputs are the base case.
        depth = 0
    else:
        # Factory ops like full, rand might not have any input nodes; they
        # fall back to a parent depth of 0.
        depth = 1 + max(
            (
                get_depth(src, depth_map)
                for src in node.all_input_nodes
                if isinstance(src, torch.fx.node.Node)
            ),
            default=0,
        )

    depth_map[node] = depth
    return depth
looking at users of + tangents) and then reorders the graph by walking from this node to all the + way to the end of the graph. At each op in this traveral, we insert this op + in a new graph and try to bring only the relevant subgraph from the other + non-bwd edges relevant for this op. This closely mimics the behavior of + autograd engine. + + Why is this pass required in the first place? + + This is an artifact of how partitioners work today. The starting point of + partitioner is a joint graph, which is fwd and then bwd graph. In the case + of checkpointing, we keep portions of fwd graph in their original place in + the joint graph, while obtaining a bwd graph. As a result, the resulting bwd + graph has copies of recomputed fwd subgraphs followed by the original bwd + graph. If we run this naively, this leads to bad memory footprint, because + the fwd subgraphs are live for way longer duration than necessary. This pass + reorders the operations such that we prioritize the ops for the original bwd + graph while only realizing those ops from the fwd graph that are necessary + at any given point in the graph. + """ + + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in gm.graph.nodes: + if node.op == "placeholder": + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + + order = {} + for idx, node in enumerate(gm.graph.nodes): + order[node] = idx + + # Populate depth for the nodes. Depth is the distance from the inputs. + depths = {} + output_node = [node for node in gm.graph.nodes if node.op == "output"][0] + get_depth(output_node, depths) + + def insert_node_in_graph(node): + if node in env: + return env[node] + + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. 
def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes):
    """Share RNG state between recomputed random ops in the fwd and bwd graphs.

    During user-driven activation checkpointing, a rng op in the fwd must
    yield the same output as its recomputed copy in the bwd. This is done by
    wrapping the random ops with functional wrappers and passing the rng state
    from the fwd graph to the bwd graph.

    There are 3 main steps:
      Step 1 - Construct a mapping of rng node between the fwd and its
               counterpart in bwd.
      Step 2 - Modify the fwd pass such that
               1) rand is replaced with the run_and_save_rng_state wrapper,
               2) the users of the original op use output[1] of the wrapper,
               3) all the rng states (output[0] of each wrapper) become extra
                  outputs, inserted before the trailing symint outputs.
      Step 3 - Modify the bwd pass such that
               1) placeholders for the stashed rng states are added just
                  before the tangents,
               2) rand is replaced with the run_with_rng_state wrapper,
               3) the stashed states are fed to these wrappers.

    Args:
        joint_module: The joint fwd/bwd graph module; only read.
        fw_module: The forward graph module; modified in place.
        bw_module: The backward graph module; modified in place.
        num_sym_nodes: Number of symint outputs at the end of the forward
            graph outputs.

    Returns:
        The ``(fw_module, bw_module)`` pair after recompilation.

    Raises:
        RuntimeError: If rng ops must be functionalized but the backward graph
            has no placeholder whose name contains ``"tangent"``.
    """
    # Unique id to generate name
    uid = itertools.count()

    def is_seeded_rng_target(node):
        # Ops that consume the global RNG carry the nondeterministic_seeded tag.
        return hasattr(node.target, "tags") and torch.Tag.nondeterministic_seeded in node.target.tags

    def get_rng_ops(gmod):
        return {
            node.name: node
            for node in gmod.graph.nodes
            if node.op == "call_function" and is_seeded_rng_target(node)
        }

    # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
    joint_graph_rng_ops = get_rng_ops(joint_module)
    fw_graph_rng_ops = get_rng_ops(fw_module)
    bw_graph_rng_ops = get_rng_ops(bw_module)
    recomputable_rng_ops_map = {}
    for node in joint_module.graph.nodes:
        if must_recompute(node) and is_seeded_rng_target(node):
            base_node = joint_graph_rng_ops[node.name]
            fw_node = fw_graph_rng_ops[node.name]
            bw_node = bw_graph_rng_ops[node.name]
            recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}

    run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
    run_with_rng_state = torch._prims.rng_prims.run_with_rng_state

    # The first tangent placeholder marks where the rng-state inputs go.
    bw_tangent_start_node = next(
        (
            node
            for node in bw_module.graph.nodes
            if node.op == "placeholder" and "tangent" in node.name
        ),
        None,
    )
    if recomputable_rng_ops_map and bw_tangent_start_node is None:
        # Previously this surfaced as an UnboundLocalError inside the loop below.
        raise RuntimeError(
            "Cannot functionalize rng ops: the backward graph has no tangent placeholder"
        )

    fw_rng_state_outputs = []
    for node_pair in recomputable_rng_ops_map.values():
        # Step 2 - Modify the fwd pass such that
        fw_node = node_pair["fwd"]
        bw_node = node_pair["bwd"]
        fw_graph = fw_module.graph
        with fw_graph.inserting_before(fw_node):
            functional_fw_node = fw_graph.create_node(
                "call_function",
                run_and_save_rng,
                args=(fw_node.target, *fw_node.args),
                kwargs=fw_node.kwargs
            )
            state = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 0), kwargs={})
            rng_output = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 1,), kwargs={})
            fw_node.replace_all_uses_with(rng_output)
            fw_graph.erase_node(fw_node)
            fw_rng_state_outputs.append(state)

        # Step 3 - Modify the bwd pass such that
        bw_graph = bw_module.graph
        with bw_graph.inserting_before(bw_tangent_start_node):
            state_name = f"rng_state_output_{next(uid)}"
            bw_rng_state_node = bw_graph.placeholder(state_name)
            bw_rng_state_node.meta["val"] = torch.cuda.get_rng_state()

        with bw_graph.inserting_before(bw_node):
            rng_output = bw_graph.create_node(
                "call_function",
                run_with_rng_state,
                args=(bw_rng_state_node, bw_node.target, *bw_node.args),
                kwargs=bw_node.kwargs
            )

            bw_node.replace_all_uses_with(rng_output)
            bw_graph.erase_node(bw_node)

    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
    # that symints are at the end of forward graph outputs. So, insert the new
    # rng states accordingly.
    fw_output_node = [node for node in fw_module.graph.nodes if node.op == "output"][0]
    fw_outputs = fw_output_node.args[0]
    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
    outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:]
    fw_module.graph.output(outputs)
    fw_module.graph.erase_node(fw_output_node)
    fw_module.recompile()
    bw_module.recompile()
    return fw_module, bw_module
+ """ + for node in joint_module.graph.nodes: + if must_recompute(node): + for user in node.users: + if must_recompute(user) and user.meta["recompute"] > node.meta["recompute"]: + node.meta["recompute"] = 0 + return joint_module + + +def fuse_aware_min_cut_partition( + joint_module: fx.GraphModule, _joint_inputs, compiler="inductor", recomputable_ops=None, + *, num_fwd_outputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + print("num_fwd_outputs : ", num_fwd_outputs) + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimintation. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. 
+ """ + try: + import networkx as nx + except ImportError as e: + raise RuntimeError("Need networkx installed to perform smart recomputation " + "heuristics") from e + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + full_bw_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + name_to_node = {} + for node in joint_module.graph.nodes: + name_to_node[node.name] = node + + def classify_nodes(joint_module): + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == 'placeholder' and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + required_bw_nodes.update(o for o in bwd_outputs if o is not None) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs) + required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes + if node.op != 'output'} + unclaimed_nodes = {node for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes} + return fwd_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes + + orig_fw_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module) + + # networkx blows up on graphs with no required backward nodes + # Since 
there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. + if len(required_bw_nodes) == 0: + return default_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) + + for node in reversed(joint_module.graph.nodes): + if node not in required_fw_nodes: + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + aten = torch.ops.aten + prims = torch.ops.prims + + # compiler == "nvfuser" is the default set of recomputable ops + default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501,B950 + view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + if compiler == "inductor": + default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, 
aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.cumsum] # noqa: E501,B950 + view_ops += [aten.view, aten.slice, aten.permute, aten.t, aten.transpose, prims.broadcast_in_dim, aten.expand, aten.as_strided] + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index] + default_recomputable_ops += view_ops + + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [ + method_to_operator(m) + for m in magic_methods + ] + + recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops) + + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501,B950 + + unrecomputable_ops = random_ops + compute_intensive_ops + + fusible_ops = recomputable_ops | set(random_ops) + + # The node match fusible pattern in byteir, it shoudle be recomputed. 
+ byteir_recompute_nodes = get_byteir_recompute_nodes(full_bw_graph, required_fw_nodes) + + if AOT_PARTITIONER_DEBUG: + joint_module_ops = { + str(node.target._overloadpacket) + for node in joint_module.graph.nodes + if node.op == "call_function" and hasattr(node.target, "_overloadpacket") + } + ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} + print("Ops banned from rematerialization: ", ops_ignored) + print() + + AGGRESSIVE_RECOMPUTATION = False + + def is_materialized_backwards(node): + cur_nodes = {node} + while len(cur_nodes) > 0: + cur = cur_nodes.pop() + for user in cur.users: + if user not in required_fw_nodes and not is_fusible(cur, user): + return True + if user not in required_fw_nodes and get_aten_target(user) in view_ops: + cur_nodes.add(user) + + return False + + def ban_recomputation(node): + if node in byteir_recompute_nodes: + return False + if "recompute" in node.meta: + return node.meta["recompute"] == 0 + elif AGGRESSIVE_RECOMPUTATION: + return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops) + else: + if node.op != 'call_function': + return False + if get_aten_target(node) not in recomputable_ops: + return True + if node.target == operator.getitem: + return False + if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: + return False + + # If a node *must* be materialized in the backwards pass, then we + # should never recompute it. This is a pretty subtle point. In + # general, the assumption we make is that recomputing a node in the + # backwards pass is "free". However, if a node must be materialized + # in the backwards pass, then recomputing it is never free. + if is_materialized_backwards(node): + return True + + # Arbitrary hack that sometimes seems to help things. The above + # modification appears to have made this heuristic a lot less critical + # for performance. + # TODO: Investigate why this hack helps. 
+ # TODO: Investigate the interaction with compiler assisted + # activation checkpointing. Removing the heuristic improves both + # memory footprint and speedup. + if not graph_has_recomputable_ops: + if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw: + return True + # If the output of an op is 4x smaller (arbitrary choice), + # then we don't allow recomputation. + input_tensors_size = sum(_size_of(i) for i in node.args if isinstance(i, fx.Node)) + output_size = _size_of(node) + return (output_size * 4 < input_tensors_size) + + def is_fusible(a, b): + return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops + + def is_materialized(node): + if node.op == 'placeholder': + return True + + return not all(is_fusible(node, user) for user in node.users) + + def is_byteir_fusible(node): + if get_aten_target(node) in [aten.transpose, aten.t]: + return all(get_aten_target(user) in compute_intensive_ops for user in node.users) + return False + + def get_node_weight(node) -> int: + + mem_sz = _size_of(node) + + # Heuristic to bias towards nodes closer to the backwards pass + # Complete guess about current value + mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) + # mem_sz = int(mem_sz + node.dist_from_bw) + + if is_materialized(node): + return mem_sz + else: + return mem_sz * 2 + + nx_graph = nx.DiGraph() + + for node in full_bw_graph.nodes: + if node.op == 'output': + continue + + if node in required_bw_nodes: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + + if _is_primal(node) or _is_fwd_seed_offset(node): + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + # If a node can't be recomputed (too expensive or involves randomness), + # we prevent it from being recomputed by adding an inf edge to the source + # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. 
+ if ban_recomputation(node) and node in required_fw_nodes: + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + # Checks if a node is actually a tuple. Can be simplified to just an isisinstance check if we always use faketensors. + is_non_tensor_node = (('val' not in node.meta and 'tensor_meta' not in node.meta) or + ('val' in node.meta and not isinstance(node.meta['val'], torch.Tensor))) + + if is_sym_node(node): + weight = sym_node_size(node) + elif is_non_tensor_node: + weight = math.inf + else: + weight = get_node_weight(node) + + # Creates the weights on the "node" edge + nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) + for user in node.users: + nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) + + for node in byteir_recompute_nodes: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + + try: + cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") + except Exception: + print('Failed to compute min-cut on following graph:') + print('\n'.join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + raise + + reachable, non_reachable = partition + cutset = set() + for u, nbrs in ((n, nx_graph[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + cut_nodes = set() + for node_in, node_out in cutset: + assert node_in[:-3] == node_out[:-4] + node_name = node_in[:-3] + cut_nodes.add(node_name) + + # To make this stuff deterministic + node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]) + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values)) + saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols + fw_module, bw_module = _extract_fwd_bwd_modules( 
+ joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs) + + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + if AOT_PARTITIONER_DEBUG: + print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9) + fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'} + bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'} + remat_nodes = fw_module_nodes & bw_module_nodes + + counts = defaultdict(int) + for node in fw_module.graph.nodes: + if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'): + counts[str(node.target._overloadpacket)] += 1 + print(f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}") + print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True)) + return fw_module, bw_module + + +def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True): + if clear_meta: + new_graph = copy.deepcopy(traced.graph) + traced = fx.GraphModule(traced, new_graph) + for node in traced.graph.nodes: + node.meta = {} + base, ext = os.path.splitext(fname) + if not ext: + ext = ".svg" + print(f"Writing FX graph to file: {base}{ext}") + g = graph_drawer.FxGraphDrawer(traced, figname) + x = g.get_main_dot_graph() + getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}") + + +def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"): + draw_graph(graph, file_name) + return default_partition(graph, joint_inputs) diff --git a/frontends/torch-frontend/third_party/patches/einsum.patch b/frontends/torch-frontend/third_party/patches/einsum.patch index 3b32cd0f5..4f7b913f9 100644 --- 
a/frontends/torch-frontend/third_party/patches/einsum.patch +++ b/frontends/torch-frontend/third_party/patches/einsum.patch @@ -1,8 +1,8 @@ diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td -index 09147dc8..4b69e9cd 100644 +index efdb89fa..c87de79e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td -@@ -7801,6 +7801,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ +@@ -4834,6 +4834,31 @@ def Torch_AtenAddmmOp : Torch_Op<"aten.addmm", [ }]; } @@ -31,27 +31,28 @@ index 09147dc8..4b69e9cd 100644 + }]; +} + - def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ + def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp -index 4dafed1d..a3c908bf 100644 +index 558e31c6..fefc337e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp -@@ -5061,6 +5061,311 @@ public: +@@ -5022,6 +5022,460 @@ public: }; } // namespace +namespace { -+// Decompose AtenEinsumOp to AtenMmOp or AtenBmmOp -+// Step 1: split input equation to input/result tokens and find batchingDims and -+// contractingDims for future use -+// Step 2: transpose the input tensors to [batchingDims[0,1,2], -+// otherDims[0,1,2], contractingDims[0,1,2]] -+// Step 3: reshape the input tensors, the final shape should -+// be[batchingDims, otherDims, contractingDims] -+// Step 4: use AtenMatmulOp to get the result, loop util we get the final -+// result ++// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce ++// operation and permute operation. Currently, this pass doesn't support ++// Hadamard product. 
The basic idea is that: ++// Step 1: split the string equation to input/result tokens and find ++// batchingDims, contractingDims, otherDims and reduceDims. ++// Step 2: permute and reshape input tensors suitable ++// for matmul operations. ++// Step 3: use AtenMatmulOp to get the result. ++// Step 4: iteratively execute step 2 & 3 until we get the final result. ++// Step 5: perform remaining permute and reduce operations. +// notice: support static shape only + +static bool parseEquation(const std::string &equation, @@ -85,95 +86,358 @@ index 4dafed1d..a3c908bf 100644 + return true; +} + -+// Prepare Tensor for Matmul Operations, we will transpose the input tensor -+// to make it in order as [batchingDims, otherDims, contractingDims] -+// example: bcwd,bcdh->bcwh -+// Step1 : [b,c,h,d] -+// Step2 : [b*c,h,d] -+// Step3 : [e(=b*c), h, d] -+static Value prepareTensorForMatmulOperations( -+ PatternRewriter &rewriter, Operation *op, Value inputTensor, -+ const SmallVector &shape, const SmallVector &contractingDims, -+ const SmallVector &batchingDims, SmallVector &finalShape, -+ const SmallVector &tokens) { -+ SmallVector otherDims; -+ Value middleDimProduct = -+ rewriter.create(op->getLoc(), rewriter.getI64IntegerAttr(1)); -+ for (size_t i = 0; i < shape.size(); ++i) { -+ if (std::find(batchingDims.begin(), batchingDims.end(), i) == -+ batchingDims.end() && -+ std::find(contractingDims.begin(), contractingDims.end(), i) == -+ contractingDims.end()) { -+ middleDimProduct = -+ rewriter.create(op->getLoc(), middleDimProduct, shape[i]); -+ otherDims.push_back(i); ++// classify every dim token into different categories. Note that although we ++// parse out reduce dims, we delay their execution until ++// `performLastPermuteAndReduce`. 
++static void parseDimTokens( ++ SmallVector &lhsTokens, SmallVector &rhsTokens, ++ SmallVector &finalResultTokens, SmallVector &contractingDims, ++ SmallVector &lhsReduceDims, SmallVector &rhsReduceDims, ++ SmallVector &batchingDims, SmallVector &lhsOtherDims, ++ SmallVector &rhsOtherDims) { ++ llvm::SmallDenseSet lhsTokenSet(lhsTokens.begin(), lhsTokens.end()); ++ llvm::SmallDenseSet rhsTokenSet(rhsTokens.begin(), rhsTokens.end()); ++ llvm::SmallDenseSet finalResultTokenSet(finalResultTokens.begin(), ++ finalResultTokens.end()); ++ ++ for (size_t i = 0; i < lhsTokens.size(); ++i) { ++ bool rhsContains = rhsTokenSet.contains(lhsTokens[i]); ++ bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]); ++ // batching dim ++ if (rhsContains && finalResultConatins) { ++ batchingDims.push_back(lhsTokens[i]); ++ // reduce dim of lhs ++ } else if (!rhsContains && !finalResultConatins) { ++ lhsReduceDims.push_back(lhsTokens[i]); ++ // other dim of lhs ++ } else if (finalResultConatins) { ++ lhsOtherDims.push_back(lhsTokens[i]); ++ // contracting dim of lhs ++ } else if (rhsContains) { ++ contractingDims.push_back(lhsTokens[i]); + } + } -+ int64_t otherDimsSize = otherDims.size(); -+ if (!batchingDims.empty()) { -+ int64_t usedOtherDim = 0; -+ Value batchingDimProduct = -+ rewriter.create(op->getLoc(), rewriter.getI64IntegerAttr(1)); -+ int64_t batchingDimsRank = batchingDims.size(); -+ for (int64_t i = 0; i < batchingDimsRank; ++i) { -+ batchingDimProduct = -+ rewriter.create(op->getLoc(), batchingDimProduct, -+ shape[batchingDims[i]]); -+ if (batchingDims[i] != i) { -+ Value batchingDim = -+ rewriter.create(op->getLoc(), -+ rewriter.getI64IntegerAttr( -+ batchingDims[i])); -+ Value indexDim = rewriter.create( -+ op->getLoc(), rewriter.getI64IntegerAttr(otherDims[usedOtherDim])); -+ inputTensor = rewriter.create( -+ op->getLoc(), op->getResultTypes(), inputTensor, batchingDim, indexDim); -+ usedOtherDim += 1; -+ } ++ ++ for (size_t i = 0; i < rhsTokens.size(); 
++i) { ++ bool lhsContains = lhsTokenSet.contains(rhsTokens[i]); ++ bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]); ++ // batching dim ++ if (lhsContains && finalResultConatins) { ++ // reduce dim of rhs ++ } else if (!lhsContains && !finalResultConatins) { ++ rhsReduceDims.push_back(rhsTokens[i]); ++ // other dim of rhs ++ } else if (finalResultConatins) { ++ rhsOtherDims.push_back(rhsTokens[i]); ++ // contracting dim of rhs ++ } else if (lhsContains) { ++ } ++ } ++} ++ ++static void generateIdealReusltDimTokens(SmallVector &batchingDims, ++ SmallVector &lhsOtherDims, ++ SmallVector &rhsOtherDims, ++ SmallVector &lhsReduceDims, ++ SmallVector &rhsReduceDims, ++ SmallVector &resultTokens) { ++ // generate ideal result dims, i.e., ++ // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, ++ // *rhsReduceDims] ++ resultTokens.insert(resultTokens.end(), batchingDims.begin(), ++ batchingDims.end()); ++ resultTokens.insert(resultTokens.end(), lhsOtherDims.begin(), ++ lhsOtherDims.end()); ++ resultTokens.insert(resultTokens.end(), lhsReduceDims.begin(), ++ lhsReduceDims.end()); ++ resultTokens.insert(resultTokens.end(), rhsOtherDims.begin(), ++ rhsOtherDims.end()); ++ resultTokens.insert(resultTokens.end(), rhsReduceDims.begin(), ++ rhsReduceDims.end()); ++} ++ ++static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, ++ Value input, SmallVector &dimTokens, ++ SmallVector &batchingDims, ++ SmallVector &contractingDims, ++ SmallVector &otherDims, ++ SmallVector &reduceDims, bool isLhs) { ++ auto inputType = input.getType().cast(); ++ llvm::SmallDenseMap dimTokenMap; ++ for (size_t idx = 0; idx < dimTokens.size(); ++idx) { ++ dimTokenMap[dimTokens[idx]] = idx; ++ } ++ ++ SmallVector permuteVec; ++ auto appendDims = [&](SmallVector dimTokens) { ++ for (auto d : dimTokens) { ++ permuteVec.push_back(rewriter.create( ++ loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); ++ } ++ }; ++ ++ appendDims(batchingDims); ++ if 
(!isLhs) ++ appendDims(contractingDims); ++ appendDims(otherDims); ++ appendDims(reduceDims); ++ if (isLhs) ++ appendDims(contractingDims); ++ ++ Value dstDims = rewriter.create( ++ loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), ++ permuteVec); ++ auto outType = inputType.getWithSizesAndDtype(std::nullopt, ++ inputType.getOptionalDtype()); ++ return rewriter.create(loc, outType, input, dstDims); ++} ++ ++// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] => ++// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] ++static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, ++ Value input, int64_t batchDimsLength, ++ int64_t contractingDimsLength, ++ int64_t otherDimsLength, ++ int64_t reduceDimsLength, bool isLhs) { ++ auto inputType = input.getType().cast(); ++ auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + ++ reduceDimsLength; ++ SmallVector inputShapeTensor; ++ for (auto i = 0; i < inputRank; ++i) { ++ inputShapeTensor.emplace_back(rewriter.create( ++ loc, input, ++ rewriter.create(loc, ++ rewriter.getI64IntegerAttr(i)))); ++ } ++ ++ SmallVector outShapeTensor; ++ Value constOne = ++ rewriter.create(loc, rewriter.getI64IntegerAttr(1)); ++ auto dimOffset = 0; ++ ++ auto appendDims = [&](int64_t dimLength) { ++ Value prod = constOne; ++ for (auto i = 0; i < dimLength; ++i) { ++ prod = rewriter.create(loc, prod, ++ inputShapeTensor[i + dimOffset]); + } -+ finalShape.push_back(batchingDimProduct); ++ outShapeTensor.emplace_back(prod); ++ dimOffset += dimLength; ++ }; ++ ++ appendDims(batchDimsLength); ++ if (!isLhs) ++ appendDims(contractingDimsLength); ++ appendDims(otherDimsLength + reduceDimsLength); ++ if (isLhs) ++ appendDims(contractingDimsLength); ++ ++ auto outShapeValue = rewriter.create( ++ loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), ++ outShapeTensor); ++ ++ auto outType = inputType.getWithSizesAndDtype(std::nullopt, ++ 
inputType.getOptionalDtype()); ++ return rewriter.create(loc, outType, input, ++ outShapeValue); ++} ++ ++static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, ++ Value lhs, SmallVector &lhsTokens, ++ Value rhs, SmallVector &rhsTokens, ++ Value &result, ++ SmallVector &resultTokens, ++ SmallVector &finalResultTokens) { ++ auto lhsType = lhs.getType().cast(); ++ auto rhsType = rhs.getType().cast(); ++ Type promotedDType; ++ ++ // promote dtype ++ if (lhsType.hasDtype() && rhsType.hasDtype()) { ++ auto lhsDtype = Torch::getScalarTypeForType(lhsType.getOptionalDtype()); ++ auto rhsDtype = Torch::getScalarTypeForType(rhsType.getOptionalDtype()); ++ auto promotedDTypeInt = ++ torch_upstream::promote_skip_undefined(lhsDtype, rhsDtype); ++ auto promotedDTypeIntValue = rewriter.create( ++ loc, rewriter.getI64IntegerAttr((int)promotedDTypeInt)); ++ auto promotedDTypeInfo = ++ getTypeForScalarType(rewriter.getContext(), promotedDTypeInt, ++ mlir::IntegerType::SignednessSemantics::Signed); ++ if (failed(promotedDTypeInfo)) ++ rewriter.notifyMatchFailure(loc, "Failed to get type for promoted dtype"); ++ promotedDType = *promotedDTypeInfo; ++ ++ auto falseValue = ++ rewriter.create(loc, rewriter.getBoolAttr(false)); ++ auto noneValue = rewriter.create(loc); ++ lhs = rewriter.create( ++ loc, ++ lhsType.getWithSizesAndDtype(lhsType.getOptionalSizes(), promotedDType), ++ lhs, promotedDTypeIntValue, falseValue, falseValue, noneValue); ++ rhs = rewriter.create( ++ loc, ++ rhsType.getWithSizesAndDtype(rhsType.getOptionalSizes(), promotedDType), ++ rhs, promotedDTypeIntValue, falseValue, falseValue, noneValue); ++ } else { ++ promotedDType = lhsType.hasDtype() ? 
lhsType.getOptionalDtype() : rhsType.getOptionalDtype(); ++ } ++ ++ llvm::SmallDenseMap lhsDimShapeMap; ++ for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { ++ char d = lhsTokens[idx]; ++ lhsDimShapeMap[d] = rewriter.create( ++ loc, lhs, ++ rewriter.create(loc, ++ rewriter.getI64IntegerAttr(idx))); + } -+ finalShape.push_back(middleDimProduct); -+ if (!contractingDims.empty()) { -+ int64_t usedOtherDim = 1; -+ int64_t rank = tokens.size(); -+ Value contractingDimProduct = -+ rewriter.create(op->getLoc(), rewriter.getI64IntegerAttr(1)); -+ int64_t contractingDimsRank = contractingDims.size(); -+ for (int64_t i = contractingDimsRank - 1; i > -1; --i) { -+ contractingDimProduct = -+ rewriter.create(op->getLoc(), contractingDimProduct, -+ shape[contractingDims[i]]); -+ if (contractingDims[i] != rank - contractingDimsRank + i) { -+ Value contractingDim = -+ rewriter.create(op->getLoc(), -+ rewriter.getI64IntegerAttr( -+ contractingDims[i])); -+ Value indexDim = rewriter.create( -+ op->getLoc(), rewriter.getI64IntegerAttr( -+ otherDims[otherDimsSize - usedOtherDim])); -+ inputTensor = rewriter.create( -+ op->getLoc(), op->getResultTypes(), inputTensor, contractingDim, indexDim); -+ usedOtherDim += 1; ++ llvm::SmallDenseMap rhsDimShapeMap; ++ for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { ++ char d = rhsTokens[idx]; ++ rhsDimShapeMap[d] = rewriter.create( ++ loc, rhs, ++ rewriter.create(loc, ++ rewriter.getI64IntegerAttr(idx))); ++ } ++ ++ // parse batch, contracting, other, reduce dims of lhs and rhs ++ SmallVector contractingDims; ++ SmallVector lhsReduceDims; ++ SmallVector rhsReduceDims; ++ SmallVector lhsOtherDims; ++ SmallVector rhsOtherDims; ++ SmallVector batchingDims; ++ parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims, ++ lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims, ++ rhsOtherDims); ++ ++ llvm::SmallDenseMap outDimShapeMap; ++ auto generateOutDimShapeMap = [&](SmallVector &dims) { ++ for (auto d : dims) { ++ bool 
lhsContains = lhsDimShapeMap.count(d) > 0; ++ bool rhsContains = rhsDimShapeMap.count(d) > 0; ++ if (lhsContains && rhsContains) { ++ outDimShapeMap[d] = rewriter.create( ++ loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); ++ } else if (lhsContains) { ++ outDimShapeMap[d] = lhsDimShapeMap[d]; ++ } else if (rhsContains) { ++ outDimShapeMap[d] = rhsDimShapeMap[d]; + } + } -+ finalShape.push_back(contractingDimProduct); ++ }; ++ ++ generateOutDimShapeMap(contractingDims); ++ generateOutDimShapeMap(batchingDims); ++ generateOutDimShapeMap(lhsReduceDims); ++ generateOutDimShapeMap(rhsReduceDims); ++ generateOutDimShapeMap(lhsOtherDims); ++ generateOutDimShapeMap(rhsOtherDims); ++ ++ if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 && ++ rhsOtherDims.size() == 0) { ++ return rewriter.notifyMatchFailure( ++ loc, "Hadamard product is currently not supported"); ++ } ++ ++ // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] ++ lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims, ++ contractingDims, lhsOtherDims, lhsReduceDims, ++ true); ++ // shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims] ++ rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims, ++ contractingDims, rhsOtherDims, rhsReduceDims, ++ false); ++ // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] ++ lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(), ++ contractingDims.size(), lhsOtherDims.size(), ++ lhsReduceDims.size(), true); ++ // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd] ++ rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(), ++ contractingDims.size(), rhsOtherDims.size(), ++ rhsReduceDims.size(), false); ++ ++ // perform matmul ++ auto outType = ++ lhsType.getWithSizesAndDtype(std::nullopt, promotedDType); ++ result = rewriter.create(loc, outType, lhs, rhs); ++ ++ // generate ideal result dims. 
++ generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims, ++ lhsReduceDims, rhsReduceDims, resultTokens); ++ ++ // reshape matmul result to ideal shape: ++ // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] => ++ // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, ++ // *rhsReduceDims] ++ SmallVector outShapeTensors; ++ for (char d : resultTokens) { ++ outShapeTensors.emplace_back(outDimShapeMap[d]); + } -+ return inputTensor; ++ ++ auto outResultShape = rewriter.create( ++ loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), ++ outShapeTensors); ++ result = rewriter.create( ++ loc, ++ lhsType.getWithSizesAndDtype(std::nullopt, promotedDType), ++ result, outResultShape); ++ return success(); +} + -+static Value createReshapedTensor(PatternRewriter &rewriter, Location loc, -+ Operation* op, Type tensorType, Value tensor, -+ SmallVector &shape) { -+ auto listType = Torch::ListType::get(Torch::IntType::get(op->getContext())); -+ Value reshapedDims = -+ rewriter.create(loc, listType, shape); -+ return rewriter.create(loc, tensorType, tensor, reshapedDims); ++static Value performLastReduceAndPermute(PatternRewriter &rewriter, ++ Location loc, Type outType, Value input, ++ SmallVector &inputTokens, ++ SmallVector &outTokens) { ++ auto inputType = input.getType().cast(); ++ ++ llvm::SmallDenseSet outTokenSet(outTokens.begin(), outTokens.end()); ++ SmallVector sumDims; ++ llvm::SmallDenseMap inputDimToIdx; ++ int64_t idx = 0; ++ for (size_t i = 0; i < inputTokens.size(); ++i) { ++ char d = inputTokens[i]; ++ if (!outTokenSet.contains(d)) { ++ sumDims.emplace_back(i); ++ } else { ++ inputDimToIdx[d] = idx++; ++ } ++ } ++ ++ if (sumDims.size() > 0) { ++ SmallVector sumDimsTensor; ++ for (auto d : sumDims) { ++ sumDimsTensor.emplace_back(rewriter.create( ++ loc, rewriter.getI64IntegerAttr(d))); ++ } ++ auto sumDimsListValue = rewriter.create( ++ loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), ++ 
sumDimsTensor); ++ auto falseValue = rewriter.create( ++ loc, rewriter.getBoolAttr(false)); ++ auto noneValue = rewriter.create(loc); ++ input = rewriter.create( ++ loc, ++ inputType.getWithSizesAndDtype(std::nullopt, ++ inputType.getOptionalDtype()), ++ input, sumDimsListValue, falseValue, noneValue); ++ } ++ ++ SmallVector permuteDimsTensor; ++ for (auto d : outTokens) { ++ permuteDimsTensor.emplace_back(rewriter.create( ++ loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); ++ } ++ auto permuteDimsListValue = rewriter.create( ++ loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), ++ permuteDimsTensor); ++ auto out = rewriter.create(loc, outType, input, ++ permuteDimsListValue); ++ return out; +} + + +class DecomposeAtenEinsumOp : public OpRewritePattern { -+ public: ++public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEinsumOp op, + PatternRewriter &rewriter) const override { @@ -185,162 +449,47 @@ index 4dafed1d..a3c908bf 100644 + SmallVector resultTokens; + SmallVector> inputTokens; + if (!parseEquation(equation, inputTokens, resultTokens)) { -+ return rewriter.notifyMatchFailure(op, "Unexpected character in equations encountered"); ++ return rewriter.notifyMatchFailure( ++ op, "Unexpected character in equations encountered"); + } + + SmallVector inputTensors; -+ SmallVector> inputShapes; + if (!getListConstructElements(op.getTensors(), inputTensors)) { + return rewriter.notifyMatchFailure( + op, "input should comes from a PrimListConstructOp"); + } + -+ for (size_t i = 0; i < inputTensors.size(); i++) { -+ BaseTensorType tensorType = -+ inputTensors[i].getType().cast(); -+ if (!tensorType.hasSizes()) { -+ return rewriter.notifyMatchFailure( -+ op, "unimplemented: input tensor must have known sizes"); -+ } -+ ArrayRef inputShape = tensorType.getSizes(); -+ SmallVector inputValueShape; -+ for (unsigned j = 0; j < inputShape.size(); j++) { -+ inputValueShape.push_back(rewriter.create( -+ loc, inputTensors[i], 
-+ rewriter.create( -+ loc, rewriter.getI64IntegerAttr(j)))); -+ } -+ inputShapes.push_back(inputValueShape); -+ } -+ -+ auto collectOperandDims = [resultTokens]( -+ const SmallVector operandShape, -+ const SmallVector operandTokens, -+ const SmallVector others, -+ SmallVectorImpl &contractingDims, -+ SmallVectorImpl &batchingDims, -+ SmallVector &dotResultTokens, -+ SmallVector &dotResultShape) { -+ llvm::SmallDenseSet othersSet(others.begin(), others.end()); -+ llvm::SmallDenseSet resultTokensSet(resultTokens.begin(), -+ resultTokens.end()); -+ for (const auto &en : llvm::enumerate(operandTokens)) { -+ bool isResultToken = resultTokensSet.contains(en.value()); -+ bool isOtherToken = othersSet.contains(en.value()); -+ if (!isResultToken && isOtherToken) { -+ contractingDims.push_back(en.index()); -+ } else if (isOtherToken) { -+ batchingDims.push_back(en.index()); -+ } else { -+ dotResultTokens.push_back(en.value()); -+ dotResultShape.push_back(operandShape[en.index()]); -+ } -+ } ++ auto allTensorHasSizes = [](Value tensor) { ++ auto type = tensor.getType().dyn_cast(); ++ if (!type || !type.hasSizes()) ++ return false; ++ return true; + }; + -+ Value constZero = -+ rewriter.create(loc, rewriter.getI64IntegerAttr(0)); -+ Value constOne = -+ rewriter.create(loc, rewriter.getI64IntegerAttr(1)); -+ Value constTwo = -+ rewriter.create(loc, rewriter.getI64IntegerAttr(2)); -+ if (inputTensors.size() == 1) { -+ return rewriter.notifyMatchFailure( -+ op, "unimplemented: single input tensor is not supported"); ++ if (!llvm::all_of(inputTensors, allTensorHasSizes)) { ++ return rewriter.notifyMatchFailure(op, ++ "all input tensors should have sizes"); + } -+ while (inputTensors.size() > 1) { -+ SmallVector lhsContractingDims, lhsBatchingDims, -+ rhsContractingDims, rhsBatchingDims; -+ SmallVector dotResultTokens; -+ SmallVector dotResultShape; -+ SmallVector lhsShape = inputShapes[0]; -+ SmallVector rhsShape = inputShapes[1]; -+ SmallVector lhsTokens = inputTokens[0]; -+ 
SmallVector rhsTokens = inputTokens[1]; -+ Value lhsTensor = inputTensors[0]; -+ Value rhsTensor = inputTensors[1]; -+ // Step 1: split input equation to input/result tokens -+ collectOperandDims(lhsShape, lhsTokens, rhsTokens, lhsContractingDims, -+ lhsBatchingDims, dotResultTokens, dotResultShape); -+ collectOperandDims(rhsShape, rhsTokens, lhsTokens, rhsContractingDims, -+ rhsBatchingDims, dotResultTokens, dotResultShape); -+ // Prepend batch tokens. -+ for (const auto &it : llvm::enumerate(lhsBatchingDims)) { -+ char batchingToken = lhsTokens[it.value()]; -+ Value batchingShapeDim = lhsShape[it.value()]; -+ dotResultTokens.insert(dotResultTokens.begin() + it.index(), -+ batchingToken); -+ dotResultShape.insert(dotResultShape.begin() + it.index(), -+ batchingShapeDim); -+ } -+ // Lowering to dot_general does not support a mismatch between the number -+ // of result dims and the number of non-contracting dims. -+ -+ SmallVector lhsFinalShape, rhsFinalShape; -+ SmallVector finalShape = dotResultShape; -+ // Step 2: transpose the input tensors to [batchingDims[0,1,2], -+ // otherDims[0,1,2], contractingDims[0,1,2]] -+ lhsTensor = prepareTensorForMatmulOperations(rewriter, op, lhsTensor, lhsShape, -+ lhsContractingDims, lhsBatchingDims, -+ lhsFinalShape, lhsTokens); -+ rhsTensor = prepareTensorForMatmulOperations(rewriter, op, rhsTensor, rhsShape, -+ rhsContractingDims, rhsBatchingDims, -+ rhsFinalShape, rhsTokens); -+ -+ // Step 3: reshape the input tensors, the final shape should -+ // be[batchingDims, otherDims, contractingDims] -+ auto listType = Torch::ListType::get(Torch::IntType::get(op->getContext())); -+ Value lhsReshapedDims = -+ rewriter.create(loc, listType, lhsFinalShape); -+ Value lhs = rewriter.create(loc, op.getType(), lhsTensor, lhsReshapedDims); -+ Value rhsReshapedDims = -+ rewriter.create(loc, listType, rhsFinalShape); -+ Value rhs = rewriter.create(loc, op.getType(), rhsTensor, rhsReshapedDims); -+ Value result; -+ -+ // Step 4: use AtenMatmulOp 
to get the result, loop util we -+ // get the final result -+ if (!rhsContractingDims.empty() && !rhsBatchingDims.empty()){ -+ rhs = rewriter.create(loc, op.getType(), rhs, constOne, constTwo); -+ } else if (!rhsContractingDims.empty()){ -+ rhs = rewriter.create(loc, op.getType(), rhs, constZero, constOne); -+ } -+ result = rewriter.create(loc, op.getType(), lhs, rhs); -+ result = createReshapedTensor(rewriter, loc, op, op.getType(), result, finalShape); -+ -+ inputTensors.erase(inputTensors.begin(), inputTensors.begin() + 2); -+ inputTokens.erase(inputTokens.begin(), inputTokens.begin() + 2); -+ inputShapes.erase(inputShapes.begin(), inputShapes.begin() + 2); -+ inputTensors.push_back(result); -+ inputTokens.push_back(dotResultTokens); -+ inputShapes.push_back(dotResultShape); -+ if (inputTokens.size() == 1) { -+ // Lowering to dot_general does not support a mismatch between the number -+ // of result dims and the number of non-contracting dims. -+ if (dotResultTokens.size() != resultTokens.size()) { -+ return rewriter.notifyMatchFailure(op, -+ "rank reducing einsum not supported"); -+ } -+ int64_t resultSize = 0; -+ for (char resultToken : resultTokens) { -+ auto *foundIt = std::find(dotResultTokens.begin(), dotResultTokens.end(), -+ resultToken); -+ if (foundIt == dotResultTokens.end()) { -+ return rewriter.notifyMatchFailure( -+ op, "result token not found in operands"); -+ } -+ auto resultIndex = std::distance(dotResultTokens.begin(), foundIt); -+ if (resultIndex > resultSize) { -+ Value first = rewriter.create(loc, rewriter.getI64IntegerAttr(resultSize)); -+ Value second = rewriter.create(loc, rewriter.getI64IntegerAttr(resultIndex)); -+ result = rewriter.create(loc, op.getType(), result, first, second); -+ } -+ resultSize += 1; -+ } -+ // The dot_general is already in an appropriate result order. 
-+ rewriter.replaceOp(op, ValueRange{result}); ++ ++ SmallVector lhsTokens = inputTokens[0]; ++ Value lhs = inputTensors[0]; ++ Value result; ++ ++ for (size_t i = 1; i < inputTensors.size(); ++i) { ++ auto rhs = inputTensors[i]; ++ auto rhsTokens = inputTokens[i]; ++ SmallVector outTokens; ++ if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens, ++ result, outTokens, resultTokens))) { ++ return failure(); + } ++ lhs = result; ++ lhsTokens = outTokens; + } ++ ++ result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs, lhsTokens, ++ resultTokens); ++ rewriter.replaceOp(op, result); + return success(); + } +}; @@ -348,16 +497,16 @@ index 4dafed1d..a3c908bf 100644 + + namespace { - class DecomposeComplexOpsPass - : public DecomposeComplexOpsBase { -@@ -5164,6 +5469,7 @@ public: - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); + // Unconditionally decompose `aten.tile` into `aten.repeat`. 
+ class DecomposeAtenTileOp : public OpRewritePattern { +@@ -5221,6 +5675,7 @@ public: + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 76119828..179440c6 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp diff --git a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp index a96b17087..3760e78c7 100644 --- a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp +++ b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp @@ -970,6 +970,10 @@ class ConvertFlashAttnFwdOp : public OpConversionPattern { Type softmaxLseTy = op.getResult(5).getType(); Type softmaxTy = op.getResult(6).getType(); Type rngTy = op.getResult(7).getType(); + // Do not need softmax return if there's no use + if (op.getResult(6).use_empty()) + returnSoftmax = false; + SmallVector resultTypes; if (failed(getTypeConverter()->convertTypes( {outputPadTy, softmaxLseTy, softmaxTy, rngTy}, resultTypes))) { diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_attn_rewrite.py b/frontends/torch-frontend/torch-frontend/python/test/test_attn_rewrite.py index 99bf2fadd..ad64e0508 100644 --- a/frontends/torch-frontend/torch-frontend/python/test/test_attn_rewrite.py +++ b/frontends/torch-frontend/torch-frontend/python/test/test_attn_rewrite.py @@ -183,3 +183,21 @@ def 
test_flash_attn_opt_pattern(): torch.testing.assert_close(golden_loss, flash_loss, atol=1e-4, rtol=1e-6) torch.testing.assert_close(golden_logits, flash_logits, atol=3e-3, rtol=1e-6) + + +def test_flash_attn_llama_inference_pattern(): + config = transformers.LlamaConfig(num_hidden_layers=4) + model = transformers.LlamaForCausalLM(config=config).to("cuda") + model.eval() + + input, label = make_data(model, "cuda") + trace_data = [input] + + from torch.fx.experimental.proxy_tensor import make_fx + from torch_frontend import preprocess_fx_graph + # module = torch.jit.trace(model, trace_data, check_trace=False) + with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): + fx_g = make_fx(model)(*trace_data) + fx_g = preprocess_fx_graph(fx_g) + all_formatted = "\n".join([n.format_node() for n in fx_g.graph.nodes]) + FileCheck().check("call_function").check("torch.ops.byteir.flash_attn_fwd").run(all_formatted) diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_fx_utils.py b/frontends/torch-frontend/torch-frontend/python/test/test_fx_utils.py new file mode 100644 index 000000000..d6143c3a1 --- /dev/null +++ b/frontends/torch-frontend/torch-frontend/python/test/test_fx_utils.py @@ -0,0 +1,18 @@ +import torch +import torch.fx as fx +import torch_frontend +from torch_frontend.fx_utils import _replace_aten_full_arugment + +class FullModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.ops.aten.full(x.shape, True, dtype=torch.bool) + return y + + +def test_full_bool_pattern(): + fx_g = fx.symbolic_trace(FullModule()) + fx_g = _replace_aten_full_arugment(fx_g) + module = torch.jit.script(fx_g) diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py index 7c6fda4a8..86d885215 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py +++ 
b/frontends/torch-frontend/torch-frontend/python/torch_frontend/__init__.py @@ -25,8 +25,7 @@ del importlib del _torch_frontend_registry -from .ts_utils import register_decomposition_in_torchscript -from .fx_utils import list_decomposed_ops, preprocess_fx_graph +from .fx_utils import list_decomposed_ops, preprocess_fx_graph, get_none_indices from .convert_to_mhlo import convert_to_mhlo_via_torch_mlir, compile from .flash_attn_op import replace_flash_attn from .fx_rewrite import fx_replace_attn_pattern diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py index eb075a648..20d95e2b5 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py @@ -76,6 +76,34 @@ def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: return removed_indexes +# note: torch.jit.script doesn't support torch.ops.aten.full([2, 1, 1, 128], True, dtype = torch.bool), replace it with torch.ops.aten.full([2, 1, 1, 128], 1, dtype = torch.bool) +def _replace_aten_full_arugment(fx_g: torch.fx.GraphModule) -> torch.fx.GraphModule : + def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + nodes = [] + for node in fx_g.graph.nodes: + if get_aten_target(node) == torch.ops.aten.full: + if node.args[1] == True or node.args[1] == False: + nodes.append(node) + for node in nodes: + if node.args[1] == True: + with fx_g.graph.inserting_after(node): + new_node = fx_g.graph.call_function(torch.ops.aten.full, args=(node.args[0], 1), kwargs=node.kwargs) + node.replace_all_uses_with(new_node) + fx_g.graph.erase_node(node) + if node.args[1] == False: + with fx_g.graph.inserting_after(node): + new_node = fx_g.graph.call_function(torch.ops.aten.full, args=(node.args[0], 0), kwargs=node.kwargs) + 
node.replace_all_uses_with(new_node) + fx_g.graph.erase_node(node) + fx_g.graph.lint() + fx_g.recompile() + return fx_g + + def threshold_backward_pattern(grad_output, inp, threshold): return torch.ops.aten.threshold_backward(grad_output, inp, threshold) @@ -96,6 +124,61 @@ def unsafe_index_put_pattern(self, indices, values, accumulate): def unsafe_index_put_replacement(self, indices, values, accumulate): return torch.ops.aten.index_put_.hacked_twin(self, indices, values, accumulate) +# LLaMA aten attention op pattern +def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): + transpose_3 = torch.ops.aten.transpose.int(key, 2, 3) + expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim]) + clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format) + _unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim]) + expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len]) + clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format) + _unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len]) + bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4) + _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len]) + div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale) + add_5 = torch.ops.aten.add.Tensor(div, attn_mask) + maximum = torch.ops.aten.maximum.default(add_5, min_val) + _softmax = torch.ops.aten._softmax.default(maximum, -1, False) + _to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16) + expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len]) + view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None + expand_5 = torch.ops.aten.expand.default(value, [batch, 
num_head, seq_len, head_dim]) + clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format) + _unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim]) + bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6) + _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim]) + return _softmax, _unsafe_view_5 + + +def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): + out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd( + query, + key, + value, + 0.0, + 1.0/inv_scale, + True, + True + ) + return S_dmask, out + + +def get_none_indices(fx_g: torch.fx.GraphModule) -> List[int]: + none_indices = [] + for node in fx_g.graph.nodes: + if node.op == "output": + assert len(node.args) == 1, "Output node must have a single argument" + node_arg = node.args[0] + if isinstance(node_arg, (list, tuple)): + node_arg = list(node_arg) + node_args_len = len(node_arg) + for i in range(node_args_len): + if node_arg[i] is None: + none_indices.append(i) + break + return none_indices + + def list_decomposed_ops(): return [ torch.ops.aten._native_batch_norm_legit_functional, @@ -108,15 +191,18 @@ def list_decomposed_ops(): torch.ops.aten.tril ] + def preprocess_fx_graph(fx_graph: torch.fx.GraphModule): if _returns_nothing(fx_graph): return fx_graph torch.fx.replace_pattern(fx_graph, squeeze_dims_pattern, squeeze_dims_replacement) torch.fx.replace_pattern(fx_graph, unsafe_index_put_pattern, unsafe_index_put_replacement) + torch.fx.replace_pattern(fx_graph, LLaMAAttnPattern, LLaMAAttnReplacement) was_unwrapped = _unwrap_single_tuple_return(fx_graph) was_list_replaced = _list_return_to_tuple_return(fx_graph) removed_none_indexes = _remove_nones(fx_graph) strip_overloads(fx_graph) torch.fx.replace_pattern(fx_graph, threshold_backward_pattern, threshold_backward_replacement) 
+ fx_graph = _replace_aten_full_arugment(fx_graph) return fx_graph diff --git a/runtime/include/brt/core/framework/op_accessor.h b/runtime/include/brt/core/framework/op_accessor.h index 8e3bb0c5d..352a63228 100644 --- a/runtime/include/brt/core/framework/op_accessor.h +++ b/runtime/include/brt/core/framework/op_accessor.h @@ -67,6 +67,9 @@ class OpAccessor { template T GetAttrAsSplatValue(const std::string &name) const; + template + std::vector GetAttrAsVector(const std::string &name) const; + std::string GetUID() const; static int64_t GetNumElementsOfShape(const Shape &shape); diff --git a/runtime/lib/backends/cuda/providers/default/ait/ait.cc b/runtime/lib/backends/cuda/providers/default/ait/ait.cc index e0e9e56e8..728a9aa70 100644 --- a/runtime/lib/backends/cuda/providers/default/ait/ait.cc +++ b/runtime/lib/backends/cuda/providers/default/ait/ait.cc @@ -370,7 +370,8 @@ AITOpKernel::AITOpKernel(const OpKernelInfo &info) std::string lib_path = brt::ir::GetParentPath(ir_path); lib_path += accessor.GetAttrAsString(std::string("ait_lib_file")); aitLibHdl = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL); - BRT_ENFORCE(aitLibHdl != nullptr, "AIT lib .so load failed"); + std::string msg = std::string("AIT lib ") + lib_path + " load failed"; + BRT_ENFORCE(aitLibHdl != nullptr, msg); std::string space = accessor.GetAttrAsString("device"); IAllocator *alloc = info_.GetAllocator(space); workspaceSizeInBytes = diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 3193bf51d..39c77d5a3 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -37,8 +37,12 @@ using namespace mlir; #define FILE_NAME_ATTR "device_file_name" #define KERNEL_NAME_ATTR "kernel_name" -#define GRID_SIZE_ATTR "GridSize.x" -#define BLOCK_SIZE_ATTR "BlockSize.x" +#define GRID_SIZE_X_ATTR "GridSize.x" +#define GRID_SIZE_Y_ATTR "GridSize.y" 
+#define GRID_SIZE_Z_ATTR "GridSize.z" +#define BLOCK_SIZE_X_ATTR "BlockSize.x" +#define BLOCK_SIZE_Y_ATTR "BlockSize.y" +#define BLOCK_SIZE_Z_ATTR "BlockSize.z" #define ARG_RANKS_ATTR "arg_ranks" #define CALL_CONVENTION_ATTR "call_convention" @@ -119,29 +123,57 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) impl_->call_convention = "all"; // static assignment for config // TODO extend to support dynamic - if (!info.GetOperation()->hasAttrOfType(GRID_SIZE_ATTR)) { + if (!info.GetOperation()->hasAttrOfType(GRID_SIZE_X_ATTR)) { BRT_THROW_EX(std::runtime_error, "no GridSize.x attr"); } - if (!info.GetOperation()->hasAttrOfType(BLOCK_SIZE_ATTR)) { + if (!info.GetOperation()->hasAttrOfType(BLOCK_SIZE_X_ATTR)) { BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr"); } - if (!info.GetOperation()->hasAttrOfType(ARG_RANKS_ATTR)) { - BRT_THROW_EX(std::runtime_error, "no arg_ranks attr"); + int gx = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_X_ATTR) + .getInt()), + gy = 1, gz = 1; + if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Y_ATTR)) { + gy = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_Y_ATTR) + .getInt()); + } + if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Z_ATTR)) { + gz = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_Z_ATTR) + .getInt()); } - int gx = static_cast( - info.GetOperation()->getAttrOfType(GRID_SIZE_ATTR).getInt()); int bx = static_cast(info.GetOperation() - ->getAttrOfType(BLOCK_SIZE_ATTR) - .getInt()); - std::vector ranks = GetIntArrayAttr( - info.GetOperation()->getAttrOfType(ARG_RANKS_ATTR)); + ->getAttrOfType(BLOCK_SIZE_X_ATTR) + .getInt()), + by = 1, bz = 1; + if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Y_ATTR)) { + by = static_cast(info.GetOperation() + ->getAttrOfType(BLOCK_SIZE_Y_ATTR) + .getInt()); + } + if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Z_ATTR)) { + bz = static_cast(info.GetOperation() + ->getAttrOfType(BLOCK_SIZE_Z_ATTR) + .getInt()); + } + + std::vector ranks; + 
if (info.GetOperation()->hasAttrOfType(ARG_RANKS_ATTR)) { + ranks = GetIntArrayAttr( + info.GetOperation()->getAttrOfType(ARG_RANKS_ATTR)); + } else { + for (unsigned int i = 0; i < GetOpArgNum(info_); ++i) { + ranks.push_back(GetRankFromOpArgIndex(info_, i)); + } + } auto num_arg = GetOpArgNum(info_); - impl_->grid = dim3(gx, 1, 1); - impl_->block = dim3(bx, 1, 1); + impl_->grid = dim3(gx, gy, gz); + impl_->block = dim3(bx, by, bz); impl_->shared_size = 0; impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_bwd.cc b/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_bwd.cc index 7dfb2baf3..4f9da6ae5 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_bwd.cc +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_bwd.cc @@ -85,10 +85,10 @@ common::Status FlashAttnBwdOpKernel::RunImpl(const ExecutionContext &ctx) { } // dropout check - bool is_dropout = p_dropout > 0.0; - if (is_dropout) { - return InvalidArgs("currently, we only support p_dropout == 0"); - } + // bool is_dropout = p_dropout > 0.0; + // if (is_dropout) { + // return InvalidArgs("currently, we only support p_dropout == 0"); + // } // type check const auto dout_type = accessor.GetArgDTypeEnum(0); @@ -288,6 +288,7 @@ common::Status FlashAttnBwdOpKernel::RunImpl(const ExecutionContext &ctx) { /* seqlen_k */ seqlen_k, /* seqlen_q_rounded */ seqlen_q_rounded, /* seqlen_k_rounded */ seqlen_k_rounded, + /* p_dropout */ p_dropout, /* is_causal */ is_causal, /* stream */ stream); diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_fwd.cc b/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_fwd.cc index 1702623e4..5a588ac82 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_fwd.cc +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/flash_attn_fwd.cc @@ -50,10 +50,16 @@ 
common::Status FlashAttnFwdOpKernel::RunImpl(const ExecutionContext &ctx) { void *q_ptr = accessor.GetArgAsyncValueRef(0); void *k_ptr = accessor.GetArgAsyncValueRef(1); void *v_ptr = accessor.GetArgAsyncValueRef(2); - void *o_ptr = accessor.GetArgAsyncValueRef(3); - void *softmax_lse_ptr = accessor.GetArgAsyncValueRef(4); - void *softmax_ptr = accessor.GetArgAsyncValueRef(5); - void *rng_state_ptr = accessor.GetArgAsyncValueRef(6); // TODO : handle rng + void *rng_state_ptr = accessor.GetArgAsyncValueRef(3); + void *o_ptr = accessor.GetArgAsyncValueRef(4); + void *softmax_lse_ptr = accessor.GetArgAsyncValueRef(5); + void *softmax_ptr = accessor.GetArgAsyncValueRef(6); + + // check rng_state + // uint64_t *h_rng_state = new uint64_t[2]; + // cudaMemcpy(h_rng_state, rng_state_ptr, 2 * sizeof(uint64_t), + // cudaMemcpyDeviceToHost); std::cout << h_rng_state[0] << "," << + // h_rng_state[1] << std::endl; cudaDeviceSynchronize(); // attr const bool is_causal = accessor.GetAttrAsBool("causal"); @@ -66,7 +72,7 @@ common::Status FlashAttnFwdOpKernel::RunImpl(const ExecutionContext &ctx) { const auto q_shape = accessor.GetArgShape(0); const auto k_shape = accessor.GetArgShape(1); const auto v_shape = accessor.GetArgShape(2); - const auto o_shape = accessor.GetArgShape(3); + const auto o_shape = accessor.GetArgShape(4); int64_t o_rank = o_shape.size(); int64_t q_rank = q_shape.size(); int64_t k_rank = k_shape.size(); @@ -115,7 +121,7 @@ common::Status FlashAttnFwdOpKernel::RunImpl(const ExecutionContext &ctx) { DTypeEnum q_dtype = accessor.GetArgDTypeEnum(0); DTypeEnum k_dtype = accessor.GetArgDTypeEnum(1); DTypeEnum v_dtype = accessor.GetArgDTypeEnum(2); - DTypeEnum o_dtype = accessor.GetArgDTypeEnum(3); + DTypeEnum o_dtype = accessor.GetArgDTypeEnum(4); if (o_dtype != q_dtype || q_dtype != k_dtype || k_dtype != v_dtype) { return InvalidArgs( "query, key, value, and output must have the same dtype"); @@ -194,6 +200,7 @@ common::Status FlashAttnFwdOpKernel::RunImpl(const 
ExecutionContext &ctx) { /* seqlen_k */ seqlen_k, /* seqlen_q_rounded */ seqlen_q_rounded, /* seqlen_k_rounded */ seqlen_k_rounded, + /* p_dropout */ p_dropout, /* is_causal */ is_causal, /* stream */ stream); diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.cu b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.cu index e2981b63a..dce323437 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.cu +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.cu @@ -142,7 +142,7 @@ void run_mha(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, - int is_causal, cudaStream_t stream) { + float p_dropout, int is_causal, cudaStream_t stream) { Flash_fwd_params params; // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -187,7 +187,7 @@ void run_mha(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; - params.p_dropout = 1.; // probability to keep + params.p_dropout = 1.f - p_dropout; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; @@ -195,6 +195,8 @@ void run_mha(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = softmax_ptr; // used for `return_softmax`. 
+ params.rng_state = static_cast(rng_state_ptr); + // print_Flash_fwd_params(params); FP16_SWITCH(!params.is_bf16, [&] { @@ -225,7 +227,7 @@ void run_mha_bwd(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, - int is_causal, cudaStream_t stream) { + float p_dropout, int is_causal, cudaStream_t stream) { Flash_bwd_params params; // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -293,7 +295,7 @@ void run_mha_bwd(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; - params.p_dropout = 1.; // probability to keep + params.p_dropout = 1.f - p_dropout; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; @@ -302,6 +304,7 @@ void run_mha_bwd(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = nullptr; // used for `return_softmax`, no use in bwd params.dsoftmax_sum = dsoftmax_sum_ptr; + params.rng_state = static_cast(rng_state_ptr); // print_Flash_bwd_params(params); diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.h b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.h index 205ee88cd..5c6e6f70c 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.h +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_api.h @@ -29,7 +29,7 @@ void run_mha(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, - int is_causal, cudaStream_t stream); + float p_dropout, int is_causal, cudaStream_t stream); void run_mha_bwd(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void 
*dout_ptr, void *dq_ptr, void *dk_ptr, void *dv_ptr, @@ -53,7 +53,7 @@ void run_mha_bwd(void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, - int is_causal, cudaStream_t stream); + float p_dropout, int is_causal, cudaStream_t stream); } // namespace kernel } // namespace cuda diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_bwd_kernel.h b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_bwd_kernel.h index b7f9d95a4..91bac6590 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_bwd_kernel.h +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_bwd_kernel.h @@ -957,8 +957,18 @@ compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, // auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % // 32; - unsigned long long seed = 0; - unsigned long long offset = 0; + // deprecated: no rng support. + // unsigned long long seed = 0; + // unsigned long long offset = 0; + + unsigned long long seed = params.rng_state[0]; + unsigned long long offset = + params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32; + + // if (block_id == 0 && tidx == 0) { + // printf("seed:%lu\n",seed); + // printf("offset:%lu\n",offset); + // } clear(acc_dv); clear(acc_dk); @@ -1693,8 +1703,18 @@ compute_dq_dk_dv_1rowblock(const Params ¶ms, const int bidb, const int bidh, // auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % // 32; - unsigned long long seed = 0; - unsigned long long offset = 0; + // deprecated: no rng support. 
+ // unsigned long long seed = 0; + // unsigned long long offset = 0; + + unsigned long long seed = params.rng_state[0]; + unsigned long long offset = + params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32; + + // if (block_id == 0 && tidx == 0) { + // printf("seed:%lu\n",seed); + // printf("offset:%lu\n",offset); + // } clear(acc_dq); diff --git a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_fwd_kernel.h b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_fwd_kernel.h index e024129f4..d89242b36 100644 --- a/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_fwd_kernel.h +++ b/runtime/lib/backends/cuda/providers/default/flash_attn/kernels/flash_fwd_kernel.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -376,8 +377,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, // unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * // 32 + tidx % 32; - unsigned long long seed = 0; - unsigned long long offset = 0; + // deprecated: no rng support. + // unsigned long long seed = 0; + // unsigned long long offset = 0; + + unsigned long long seed = params.rng_state[0]; + unsigned long long offset = + params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32; + + // if (block_id == 0 && tidx == 0) { + // printf("seed:%lu\n",seed); + // printf("offset:%lu\n",offset); + // } // Save seed and offset for backward. 
// if (block_id == 0 && tidx == 0) { diff --git a/runtime/lib/backends/cuda/providers/default/tensor_generate/fill.cc b/runtime/lib/backends/cuda/providers/default/tensor_generate/fill.cc index 582bc8ea0..e04e7a751 100644 --- a/runtime/lib/backends/cuda/providers/default/tensor_generate/fill.cc +++ b/runtime/lib/backends/cuda/providers/default/tensor_generate/fill.cc @@ -40,14 +40,20 @@ common::Status FillOpKernel::RunImpl(const ExecutionContext &ctx) { static_cast(ctx.work_queue)->GetComputeStream(); void *device_p = accessor.GetArgAsyncValueRef(0); size_t length = accessor.GetNumElementsOfShape(accessor.GetArgShape(0)); + // TODO: common helper for dtype dispatch #define CASE(dtype, ctype, mlir_type) \ - case DTypeEnum::dtype: \ - kernel::Fill( \ - stream, static_cast(device_p), \ - static_cast(accessor.GetAttrAsSplatValue("value")), \ - length); \ - return common::Status::OK() + case DTypeEnum::dtype: { \ + if (accessor.HasAttrOfSplatValue("value")) { \ + kernel::Fill(stream, static_cast(device_p), \ + static_cast( \ + accessor.GetAttrAsSplatValue("value")), \ + length); \ + return common::Status::OK(); \ + } \ + break; \ + } + switch (dtype) { CASE(Float32, float, float); CASE(Int64, int64_t, int64_t); @@ -55,11 +61,30 @@ common::Status FillOpKernel::RunImpl(const ExecutionContext &ctx) { CASE(Float16, __half, float); #undef CASE default: - return common::Status(common::StatusCategory::BRT, - common::StatusCode::NOT_IMPLEMENTED, - "not supported dtype"); + break; }; - return common::Status::OK(); + +#define CASE(dtype, ctype) \ + case DTypeEnum::dtype: { \ + std::vector value = accessor.GetAttrAsVector("value"); \ + cudaMemcpyAsync(device_p, value.data(), value.size() * sizeof(ctype), \ + cudaMemcpyHostToDevice, stream); \ + return common::Status::OK(); \ + } + + switch (dtype) { + CASE(Float32, float); + CASE(Int64, int64_t); + CASE(Float64, double); + CASE(Float16, half_float::half); +#undef CASE + default: + break; + }; + + return 
common::Status(common::StatusCategory::BRT, + common::StatusCode::NOT_IMPLEMENTED, + "not supported FillOp"); } common::Status FillOpKernel::ProloguePerFrame(const ExecutionContext &) { diff --git a/runtime/lib/core/framework/op_accessor.cc b/runtime/lib/core/framework/op_accessor.cc index 18176432e..975f0982e 100644 --- a/runtime/lib/core/framework/op_accessor.cc +++ b/runtime/lib/core/framework/op_accessor.cc @@ -171,6 +171,31 @@ T OpAccessor::GetAttrAsSplatValue(const std::string &name) const { BRT_THROW("Attribute " + name + " is not set"); } +// GetDenseAttrAsVector will iterate every elements in dense attibutes. +// If you want to avoid iterating, consider use getRawData() but special handle +// for i1 ??? +template +std::vector OpAccessor::GetAttrAsVector(const std::string &name) const { + std::vector results; + if (auto attr = + info_.GetOperation()->getAttrOfType(name)) { + results.reserve(attr.size()); + for (APInt &&i : attr) { + results.push_back(static_cast(i.getSExtValue())); + } + return results; + } else if (auto attr = + info_.GetOperation()->getAttrOfType( + name)) { + results.reserve(attr.size()); + for (APFloat &&i : attr) { + results.push_back(static_cast(i.convertToDouble())); + } + return results; + } + BRT_THROW("Attribute " + name + " is not supported to get as vector"); +} + std::string OpAccessor::GetUID() const { auto byre_op = llvm::cast(info_.GetOperation()); return ByREHandle::GetOpUID(byre_op); @@ -211,6 +236,18 @@ INST_ATTR_METH(double) INST_ATTR_METH(StringView) #undef INST_ATTR_METH +#define INST_DENSE_ATTR_METH(T) \ + template std::vector OpAccessor::GetAttrAsVector(const std::string &) \ + const; +INST_DENSE_ATTR_METH(float) +INST_DENSE_ATTR_METH(int32_t) +INST_DENSE_ATTR_METH(int64_t) +INST_DENSE_ATTR_METH(uint8_t) +INST_DENSE_ATTR_METH(uint32_t) +INST_DENSE_ATTR_METH(double) +INST_DENSE_ATTR_METH(half_float::half) +#undef INST_DENSE_ATTR_METH + #define INST_SCALAR_METH(T) \ template T OpAccessor::GetArgScalar(size_t); \ 
template common::Status OpAccessor::SetResultScalar(size_t result_idx, \ diff --git a/runtime/test/backends/cuda/providers/default/kernel/fill_test.cc b/runtime/test/backends/cuda/providers/default/kernel/fill_test.cc index 6c870d93b..9f8c30335 100644 --- a/runtime/test/backends/cuda/providers/default/kernel/fill_test.cc +++ b/runtime/test/backends/cuda/providers/default/kernel/fill_test.cc @@ -32,8 +32,6 @@ using namespace brt::cuda; using namespace brt::test; TEST(CUDATestFillOp, Basic) { - constexpr size_t length = 512 * 128; - Session session; auto status_allocator = CUDAAllocatorFactory(&session); BRT_TEST_CHECK_STATUS(status_allocator); @@ -54,8 +52,16 @@ TEST(CUDATestFillOp, Basic) { auto status_sync = request->Sync(); BRT_TEST_CHECK_STATUS(status_sync); + size_t length = 512 * 128; CheckCUDAValues(static_cast(request->GetArg(0)), length, 0.f); CheckCUDAValues(static_cast(request->GetArg(1)), length, 1.f); CheckCUDAValues<__half>(static_cast<__half *>(request->GetArg(2)), length, static_cast<__half>(1.f)); + length = 3; + std::vector results = {static_cast(1.f), + static_cast(2.f), + static_cast(3.f)}; + EXPECT_TRUE(CheckCUDAValuesWithCPUValues( + static_cast<__half *>(request->GetArg(3)), + reinterpret_cast<__half *>(results.data()), length)); } diff --git a/runtime/test/backends/cuda/providers/default/kernel/flash_attn_fwd_test.cc b/runtime/test/backends/cuda/providers/default/kernel/flash_attn_fwd_test.cc index f5a2d9a3a..38a0bc87b 100644 --- a/runtime/test/backends/cuda/providers/default/kernel/flash_attn_fwd_test.cc +++ b/runtime/test/backends/cuda/providers/default/kernel/flash_attn_fwd_test.cc @@ -51,6 +51,7 @@ TEST(SM80CUDATestFlashAttnFwd, Basic) { size_t head_dims = 32; size_t input_len = b * seq_len * num_heads * head_dims; size_t softmax_len = b * seq_len * num_heads; + // size_t rng_state_len = 2; Session session; auto status_allocator = CUDAAllocatorFactory(&session); @@ -71,12 +72,22 @@ TEST(SM80CUDATestFlashAttnFwd, Basic) { __half *d_v; 
float *d_softmax_lse; + // rng_state + // uint64_t *d_rng_state; + // uint64_t h_rng_state[2]; + // h_rng_state[0] = 0UL; + // h_rng_state[1] = 3000UL; + cudaMalloc(&d_o, input_len * sizeof(__half)); cudaMalloc(&d_q, input_len * sizeof(__half)); cudaMalloc(&d_k, input_len * sizeof(__half)); cudaMalloc(&d_v, input_len * sizeof(__half)); cudaMalloc(&d_softmax_lse, softmax_len * sizeof(float)); + // cudaMalloc(&d_rng_state, rng_state_len * sizeof(uint64_t)); + // cudaMemcpy(d_rng_state, h_rng_state, rng_state_len * sizeof(uint64_t), + // cudaMemcpyHostToDevice); + ReadCUDAFloatValues(d_q, input_len, input_q_file); ReadCUDAFloatValues(d_k, input_len, input_k_file); ReadCUDAFloatValues(d_v, input_len, input_v_file); @@ -96,6 +107,7 @@ TEST(SM80CUDATestFlashAttnFwd, Basic) { request->BindArg(2, d_v); request->BindArg(3, d_o); request->BindArg(4, d_softmax_lse); + // request->BindArg(6, d_rng_state); request->FinishIOBinding(); @@ -104,7 +116,7 @@ TEST(SM80CUDATestFlashAttnFwd, Basic) { auto status_sync = request->Sync(); BRT_TEST_CHECK_STATUS(status_sync); - // PrintCUDAValues(d_o, input_len, input_len); + PrintCUDAValues(d_o, input_len, input_len); CheckCUDABuffer<__half>( (__half *)d_o, /* size */ input_len, [&](__half *h_ptr) { diff --git a/runtime/test/include/brt/test/common/cuda/util.h b/runtime/test/include/brt/test/common/cuda/util.h index 6e63144e8..3def1a9af 100644 --- a/runtime/test/include/brt/test/common/cuda/util.h +++ b/runtime/test/include/brt/test/common/cuda/util.h @@ -146,6 +146,19 @@ template return passed; } +template +[[nodiscard]] bool CheckCUDAValuesWithCPUValues(T *first, T *second, + size_t size, + size_t print_count = 10) { + cudaDeviceSynchronize(); + T *h_first = (T *)malloc(size * sizeof(T)); + cudaMemcpy(h_first, first, size * sizeof(T), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + bool passed = CheckCPUValues(h_first, second, size, print_count); + free(h_first); + return passed; +} + // print floating point values template 
::value, int> = 0> diff --git a/runtime/test/test_files/fill_cuda.mlir b/runtime/test/test_files/fill_cuda.mlir index 5c25bf916..81b78c438 100644 --- a/runtime/test/test_files/fill_cuda.mlir +++ b/runtime/test/test_files/fill_cuda.mlir @@ -1,10 +1,12 @@ module attributes {byre.container_module} { func.func @test_fill(%arg0 : memref<512x128xf32, "cuda"> {byre.argname = "Fill0", byre.argtype = 2: i32}, %arg1 : memref<512x128xf32, "cuda"> {byre.argname = "Fill1", byre.argtype = 2: i32}, - %arg2 : memref<512x128xf16, "cuda"> {byre.argname = "Fill1FP16", byre.argtype = 2: i32}) attributes {byre.entry_point} { + %arg2 : memref<512x128xf16, "cuda"> {byre.argname = "Fill1FP16", byre.argtype = 2: i32}, + %arg3 : memref<3xf16, "cuda"> {byre.argname = "FillNonSplat", byre.argtype = 2: i32}) attributes {byre.entry_point} { byre.compute @FillOp(%arg0) {value = dense<0.000000e+00> : tensor<512x128xf32>} : memref<512x128xf32, "cuda"> byre.compute @FillOp(%arg1) {value = dense<1.000000e+00> : tensor<512x128xf32>} : memref<512x128xf32, "cuda"> byre.compute @FillOp(%arg2) {value = dense<1.000000e+00> : tensor<512x128xf16>} : memref<512x128xf16, "cuda"> + byre.compute @FillOp(%arg3) {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16>} : memref<3xf16, "cuda"> return } } \ No newline at end of file diff --git a/runtime/test/test_files/flash_attn_fwd.mlir b/runtime/test/test_files/flash_attn_fwd.mlir index eab8e694a..358b551f9 100644 --- a/runtime/test/test_files/flash_attn_fwd.mlir +++ b/runtime/test/test_files/flash_attn_fwd.mlir @@ -6,7 +6,7 @@ module attributes {byre.container_module} { %arg4 : memref<1x3x128xf32, "cuda"> {byre.argname = "SoftmaxLse", byre.argtype = 2: i32}, %arg5 : memref<1x3x128x128xf32, "cuda"> {byre.argname = "SoftmaxPtr", byre.argtype = 2: i32}, %arg6 : memref<2xi64, "cuda"> {byre.argname = "RngState", byre.argtype = 2: i32}) attributes {byre.entry_point} { - byre.compute @byteir.flash_attn_fwd(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, 
%arg6) {causal = true, dropout_p = 0.000000e+00 : f32, return_softmax = false, softmax_scale = 0.500000e+00 : f32} : memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda">, memref<2xi64, "cuda"> + byre.compute @byteir.flash_attn_fwd(%arg0, %arg1, %arg2, %arg6, %arg3, %arg4, %arg5) {causal = true, dropout_p = 0.000000e+00 : f32, return_softmax = false, softmax_scale = 0.500000e+00 : f32} : memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<2xi64, "cuda">, memref<1x128x3x32xf16, "cuda">, memref<1x3x128xf32, "cuda">, memref<1x3x128x128xf32, "cuda"> return } -} \ No newline at end of file +} diff --git a/tests/numerical_test/execute.py b/tests/numerical_test/execute.py index 46e67351a..c265fd1d7 100644 --- a/tests/numerical_test/execute.py +++ b/tests/numerical_test/execute.py @@ -91,6 +91,7 @@ def compile_and_run_mlir(mhlo_file, target): interp = Interpreter.load_from_file(mhlo_file) np_inputs = generate_np_inputs(interp) func_name = get_entry_func_name(interp) + unique_name = os.path.basename(mhlo_file).split('.')[0] # run golden golden_outputs = interp.call_function(func_name, np_inputs) @@ -98,8 +99,8 @@ def compile_and_run_mlir(mhlo_file, target): # byteir compile TEMP_FOLDER = "./local_test" os.makedirs(TEMP_FOLDER, exist_ok=True) - os.makedirs(TEMP_FOLDER + f"/{func_name}", exist_ok=True) - output_mlir_file_name = f'{TEMP_FOLDER}/{func_name}/{func_name}.rt.mlir' + os.makedirs(TEMP_FOLDER + f"/{unique_name}", exist_ok=True) + output_mlir_file_name = f'{TEMP_FOLDER}/{unique_name}/{unique_name}.rt.mlir' byteir.compile(mhlo_file, output_mlir_file_name, entry_func=func_name, target=target) except Exception as e: diff --git a/tests/numerical_test/main.py b/tests/numerical_test/main.py index ae2c416e4..ed46a05cd 100644 --- a/tests/numerical_test/main.py +++ 
b/tests/numerical_test/main.py @@ -31,7 +31,7 @@ parser.add_argument("--target", type=str, default="cuda_with_ait", choices=["ait", "cuda", "cuda_with_ait_aggressive"], help="target device name") parser.add_argument("-c", "--config", default="all", - choices=["all", "mlir", "torch"], help="test sets to run.") + choices=["all", "mlir", "torch", "dynamo"], help="test sets to run.") args = parser.parse_args() EXCLUDE_MLIR_TESTS = [] @@ -40,10 +40,11 @@ SM80_PLUS_TESTS = [ "dot_f32.mlir", + "bmm_rrr_permute_f16.mlir", "bmm_rrr_permute_f32.mlir", "MatmulF32Module_basic", "BatchMatmulAddF32Module_basic", - "BatchMatmulF32Module", + "BatchMatmulF32Module_basic", ] @@ -115,13 +116,18 @@ def main(): if args.config == 'all': results = run_mlir_test(arch) results = results + run_torch_test(arch) + # TODO(zzk): disable flash attn test for now + # run_torch_dynamo_tests(arch) elif args.config == 'mlir': results = run_mlir_test(arch) elif args.config == 'torch': results = run_torch_test(arch) + elif args.config == 'dynamo': + # TODO(zzk): use test infra for dynamo tests + # TODO(zzk): disable flash attn test for now + # run_torch_dynamo_tests(arch) + pass failed = report_results(results) - # TODO(zzk): disable flash attn test for now - # run_torch_dynamo_tests(arch) sys.exit(1 if failed else 0) diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f16.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f16.mlir new file mode 100644 index 000000000..3c0cb7a86 --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f16.mlir @@ -0,0 +1,6 @@ +func.func @bmm_rrr_permute(%arg0: tensor<32x64x64xf16>, %arg1: tensor<32x64x128xf16>) -> tensor<1x64x32x128xf16> { + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<32x64x64xf16>, tensor<32x64x128xf16>) -> tensor<32x64x128xf16> + %1 = mhlo.reshape %0 : (tensor<32x64x128xf16>) -> tensor<1x32x64x128xf16> + %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 2, 1, 3]> : 
tensor<4xi64>} : (tensor<1x32x64x128xf16>) -> tensor<1x64x32x128xf16> + return %2 : tensor<1x64x32x128xf16> +} \ No newline at end of file diff --git a/tests/numerical_test/mlir_tests/ops/concat2.mlir b/tests/numerical_test/mlir_tests/ops/concat2.mlir new file mode 100644 index 000000000..fe7d7779a --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/concat2.mlir @@ -0,0 +1,6 @@ +func.func @concat2(%arg0: tensor, %arg1: tensor) -> (tensor<2xi64>) { + %0 = mhlo.reshape %arg0 : (tensor) -> tensor<1xi64> + %1 = mhlo.reshape %arg1 : (tensor) -> tensor<1xi64> + %2 = "mhlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> + return %2 : tensor<2xi64> +} diff --git a/tests/numerical_test/torch_dynamo_e2e_testing/backend.py b/tests/numerical_test/torch_dynamo_e2e_testing/backend.py index 0a7ed28fc..a76663efd 100644 --- a/tests/numerical_test/torch_dynamo_e2e_testing/backend.py +++ b/tests/numerical_test/torch_dynamo_e2e_testing/backend.py @@ -18,7 +18,7 @@ import byteir from torch_frontend import compile -from torch_frontend import list_decomposed_ops, preprocess_fx_graph, fx_replace_attn_pattern, replace_flash_attn +from torch_frontend import list_decomposed_ops, preprocess_fx_graph, fx_replace_attn_pattern, replace_flash_attn, get_none_indices from functorch.compile import aot_module from torch._decomp import get_decompositions @@ -67,22 +67,6 @@ def __call__(self, *inputs): ret_ptr += 1 return results -def get_none_indices(fx_g: torch.fx.GraphModule) -> List[int]: - none_indices = [] - for node in fx_g.graph.nodes: - if node.op == "output": - assert len(node.args) == 1, "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, (list, tuple)): - node_arg = list(node_arg) - node_args_len = len(node_arg) - for i in range(node_args_len): - if node_arg[i] is None: - none_indices.append(i) - break - - return none_indices - def byteir_compile_fx_inner(graph: torch.fx.GraphModule, inputs, 
is_backward, ban_lst=[]): category = 'backward' if is_backward else 'forward'