diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ab0d9e185360..7ccfccb4f617 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -752,6 +753,84 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
   return success();
 }
 
+LogicalResult MMAAttr::materializeOperandConcreteShape(
+    OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
+    std::optional<ArrayRef<int64_t>> permutation,
+    SmallVector<ReassociationIndices> &reassociations,
+    RankedTensorType &resultType) const {
+  OpaqueMmaLayout opaqueLayout =
+      getOpaqueMFMALayout(operand.getContext(), getIntrinsic().getValue());
+  // TODO(Max191): The `getConcreteMFMALayout` function creates some
+  // `PerDimLayoutAttr` that are not used by this function. This means that
+  // any pass that uses `materializeOperandConcreteShape` needs to be
+  // dependent on the VectorExt dialect. Ideally, the `getConcreteMFMALayout`
+  // function should be refactored so we can reuse the shape information of
+  // the layout without needing to create any `PerDimLayoutAttr`.
+  ConcreteMmaLayout layout =
+      getConcreteMFMALayout(operand.getContext(), getIntrinsic().getValue());
+  SmallVector<ArrayRef<int64_t>> concreteSizes;
+  SmallVector<int64_t> opaqueSizes;
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs: {
+    concreteSizes.push_back(layout.aMLayout.getShapes());
+    concreteSizes.push_back(layout.aKLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.mSize);
+    opaqueSizes.push_back(opaqueLayout.kSize);
+    break;
+  }
+  case IREE::GPU::MMAFragment::Rhs: {
+    concreteSizes.push_back(layout.bKLayout.getShapes());
+    concreteSizes.push_back(layout.bNLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.kSize);
+    opaqueSizes.push_back(opaqueLayout.nSize);
+    break;
+  }
+  case IREE::GPU::MMAFragment::Acc: {
+    concreteSizes.push_back(layout.cMLayout.getShapes());
+    concreteSizes.push_back(layout.cNLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.mSize);
+    opaqueSizes.push_back(opaqueLayout.nSize);
+    break;
+  }
+  }
+  if (permutation.has_value()) {
+    if (permutation.value().size() != opaqueSizes.size()) {
+      return failure();
+    }
+    applyPermutationToVector(concreteSizes, permutation.value());
+    applyPermutationToVector(opaqueSizes, permutation.value());
+  }
+
+  // Inner tile must have sizes matching the opaque layout.
+  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+  ArrayRef<int64_t> operandShape = operandType.getShape();
+  SmallVector<int64_t> innerShape(operandShape.end() - opaqueSizes.size(),
+                                  operandShape.end());
+  if (!llvm::equal(opaqueSizes, innerShape)) {
+    return failure();
+  }
+
+  // Expand the shape of the inner tile to reflect the MMA thread layout.
+  SmallVector<int64_t> resultShape(operandShape.begin(),
+                                   operandShape.end() - 2);
+  SmallVector<ReassociationIndices> reInds =
+      llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
+                          [](int64_t idx) -> ReassociationIndices {
+                            return ReassociationIndices({idx});
+                          });
+  int64_t idx = reInds.size();
+  for (ArrayRef<int64_t> sizes : concreteSizes) {
+    resultShape.append(SmallVector<int64_t>(sizes));
+    reInds.push_back(
+        llvm::to_vector(llvm::seq<int64_t>(idx, idx + sizes.size())));
+    idx += sizes.size();
+  }
+
+  reassociations = reInds;
+  resultType = operandType.clone(resultShape);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MMA Schedule Attributes
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index e421f4ea8e34..6809bc6f3009 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -131,6 +131,7 @@ class IREEGPU_MmaVectorLayoutAttr<string attrname, string mmaintrinsic> :
       "getSubgroupSize",
       "buildMmaOperation",
       "populateOperandOffsetsSizesStrides",
+      "materializeOperandConcreteShape",
     ]>
   ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
index d706154474cb..88c345d97416 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
@@ -127,6 +127,26 @@ def IREEGPU_MmaInterfaceAttr : AttrInterface<"MmaInterfaceAttr"> {
         return failure();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Constructs the reassociation indices and expanded result type needed
+        to materialize the concrete thread-layout shape of the given operand
+        fragment with a tensor.expand_shape op.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"materializeOperandConcreteShape",
+      /*args=*/(ins
+        "::mlir::OpBuilder&":$builder,
+        "::mlir::iree_compiler::IREE::GPU::MMAFragment":$fragment,
+        "::mlir::Value":$operand,
+        "std::optional<::llvm::ArrayRef<int64_t>>":$permutation,
+        "::llvm::SmallVector<::mlir::SmallVector<int64_t, 2>>&":$reassociations,
+        "::mlir::RankedTensorType&":$result_type
+      ),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return failure();
+      }]
+    >,
   ];
 }
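
Note: callers are expected to drive this interface method roughly as the rewrite pattern added below does. A condensed sketch, with `kind`, `fragment`, `operand`, and `permutation` standing in for values the caller already has:

    // Ask the MMA attribute for the reassociation indices and expanded type
    // of this fragment, then materialize the concrete thread-layout shape.
    SmallVector<ReassociationIndices> reassociations;
    RankedTensorType concreteType;
    if (failed(kind.materializeOperandConcreteShape(
            rewriter, fragment, operand, permutation, reassociations,
            concreteType)))
      return failure();
    Value expanded = rewriter.create<tensor::ExpandShapeOp>(
        loc, concreteType, operand, reassociations);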
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 00af941c2721..84a0303b8471 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -51,6 +51,7 @@ iree_gentbl_cc_library(
 iree_compiler_cc_library(
     name = "GPUTransforms",
     srcs = [
+        "ConcretizeMmaShapes.cpp",
        "DistributeMmaToLanes.cpp",
        "FuseAndHoistParallelLoops.cpp",
        "LowerIREEGPUOps.cpp",
@@ -69,6 +70,7 @@ iree_compiler_cc_library(
        ":PassesIncGen",
        "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
        "//compiler/src/iree/compiler/Codegen/Transforms",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 7d49b1df7e72..e563b0427d88 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -45,6 +45,7 @@ iree_cc_library(
     "Passes.h.inc"
     "Transforms.h"
   SRCS
+    "ConcretizeMmaShapes.cpp"
    "DistributeMmaToLanes.cpp"
    "FuseAndHoistParallelLoops.cpp"
    "LowerIREEGPUOps.cpp"
@@ -80,6 +81,7 @@ iree_cc_library(
    MLIRVectorUtils
    iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
    iree::compiler::Codegen::Transforms
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
new file mode 100644
index 000000000000..9910840bc694
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
@@ -0,0 +1,154 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_CONCRETIZEMMASHAPESPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct ConcretizeMmaShapesPass final
+    : impl::ConcretizeMmaShapesPassBase<ConcretizeMmaShapesPass> {
+  using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ConcretizeMmaOperandShape(MLIRContext *context, MMAFragment fragment)
+      : OpRewritePattern(context), fragment(fragment) {}
+
+  LogicalResult matchAndRewrite(MultiMmaOp mmaOp,
+                                PatternRewriter &rewriter) const override {
+    if (!mmaOp.hasTensorSemantics()) {
+      return failure();
+    }
+
+    // Get the right operand and permutation for the `fragment`.
+    Value operand;
+    std::optional<ArrayRef<int64_t>> permutation;
+    switch (fragment) {
+    case MMAFragment::Lhs:
+      operand = mmaOp.getLhs();
+      permutation = mmaOp.getLhsPermutation();
+      break;
+    case MMAFragment::Rhs:
+      operand = mmaOp.getRhs();
+      permutation = mmaOp.getRhsPermutation();
+      break;
+    case MMAFragment::Acc:
+      operand = mmaOp.getAcc();
+      permutation = mmaOp.getAccPermutation();
+      break;
+    }
+
+    // Get the reassociation indices and result type of the expand_shape op.
+    MmaInterfaceAttr kind = mmaOp.getKind();
+    SmallVector<ReassociationIndices> reassociations;
+    RankedTensorType concreteType;
+    if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
+                                                    permutation, reassociations,
+                                                    concreteType))) {
+      return failure();
+    }
+
+    // Create the expand_shape.
+    Location loc = mmaOp->getLoc();
+    Value concreteOperand = rewriter
+                                .create<tensor::ExpandShapeOp>(
+                                    loc, concreteType, operand, reassociations)
+                                .getResult();
+
+    // Expand the permutation for the new inner dimensions of the expanded
+    // multi_mma operand.
+    auto expandPerm =
+        [&](std::optional<ArrayRef<int64_t>> perm, MMAFragment frag,
+            int64_t outerRank) -> std::optional<DenseI64ArrayAttr> {
+      if (!perm.has_value()) {
+        return std::nullopt;
+      }
+      if (frag != fragment) {
+        return rewriter.getDenseI64ArrayAttr(perm.value());
+      }
+      SmallVector<ReassociationIndices> innerReInds(
+          reassociations.begin() + outerRank, reassociations.end());
+      for (auto &reInd : innerReInds) {
+        for (auto &idx : reInd) {
+          idx -= outerRank;
+        }
+      }
+      SmallVector<int64_t> expandedPerm;
+      for (auto reInd : applyPermutation(innerReInds, perm.value())) {
+        expandedPerm.append(reInd);
+      }
+      return rewriter.getDenseI64ArrayAttr(expandedPerm);
+    };
+    std::optional<DenseI64ArrayAttr> lhsPerm = expandPerm(
+        mmaOp.getLhsPermutation(), MMAFragment::Lhs, mmaOp.getLhsOuterRank());
+    std::optional<DenseI64ArrayAttr> rhsPerm = expandPerm(
+        mmaOp.getRhsPermutation(), MMAFragment::Rhs, mmaOp.getRhsOuterRank());
+    std::optional<DenseI64ArrayAttr> accPerm = expandPerm(
+        mmaOp.getAccPermutation(), MMAFragment::Acc, mmaOp.getAccOuterRank());
+
+    // Create the new multi_mma op with the concrete type.
+    auto concreteMmaOp = rewriter.create<MultiMmaOp>(
+        loc,
+        /*lhs=*/fragment == MMAFragment::Lhs ? concreteOperand : mmaOp.getLhs(),
+        /*rhs=*/fragment == MMAFragment::Rhs ? concreteOperand : mmaOp.getRhs(),
+        /*acc=*/fragment == MMAFragment::Acc ? concreteOperand : mmaOp.getAcc(),
+        mmaOp.getIndexingMaps(), mmaOp.getIteratorTypes(), mmaOp.getKind(),
+        lhsPerm, rhsPerm, accPerm);
+
+    if (auto config = getLoweringConfig(mmaOp)) {
+      setLoweringConfig(concreteMmaOp, config);
+    }
+
+    if (fragment != MMAFragment::Acc) {
+      rewriter.replaceOp(mmaOp, concreteMmaOp);
+      return success();
+    }
+
+    // For the Acc operand, the result needs to be collapsed back to the
+    // original type so that types match with consumers.
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        mmaOp, mmaOp.getAccType(), concreteMmaOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  MMAFragment fragment;
+};
+
+void ConcretizeMmaShapesPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+
+  RewritePatternSet patterns(context);
+  if (concretizeInputs) {
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Lhs);
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Rhs);
+  }
+  if (concretizeResult) {
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Acc);
+  }
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
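
To make the expansion concrete, consider the MFMA_F32_16x16x16_F16 case from the new test below: only the K dimension of the Lhs tile is split across threads, so the Lhs pattern materializes

    %expanded_lhs = tensor.expand_shape %lhs [[0], [1], [2], [3, 4]]
        output_shape [2, 2, 16, 4, 4]
        : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>

and rebuilds the multi_mma with the tensor<2x2x16x4x4xf16> operand. The reassociation [[0], [1], [2], [3, 4]] keeps the two outer dims and the M dim intact while expanding the opaque K=16 into the concrete 4x4 thread layout.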
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 9f64563ef0ad..6cc7f11e6f74 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -20,6 +20,24 @@ def DistributeMmaToLanesPass :
   ];
 }
 
+def ConcretizeMmaShapesPass :
+    InterfacePass<"iree-gpu-concretize-mma-shapes", "mlir::FunctionOpInterface"> {
+  let summary = "Expands the inner dimensions of iree_gpu.multi_mma ops to match the thread layout";
+  let dependentDialects = [
+    "::mlir::tensor::TensorDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+    "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect",
+  ];
+  let options = [
+    Option<"concretizeInputs", "concretize-inputs",
+           "bool", /*default=*/"true",
+           "Expand the inner dimensions for the lhs and rhs operands of the multi_mma ops.">,
+    Option<"concretizeResult", "concretize-result",
+           "bool", /*default=*/"true",
+           "Expand the inner dimensions for the acc operand of the multi_mma ops.">,
+  ];
+}
+
 def FuseAndHoistParallelLoopsPass :
     InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
   let summary = "Greedily fuses and hoists parallel loops.";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
index 5e9c0e76f2d6..8348d9fb4073 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -18,6 +18,7 @@ iree_lit_test_suite(
     name = "lit",
     srcs = enforce_glob(
         [
+            "concretize_mma_shapes.mlir",
            "distribute_mma_to_lanes.mlir",
            "fuse_and_hoist_forall.mlir",
            "pack_to_intrinsics.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
index ef55e3d7faea..a71fd9ce61d5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@ iree_lit_test_suite(
   NAME
     lit
   SRCS
+    "concretize_mma_shapes.mlir"
    "distribute_mma_to_lanes.mlir"
    "fuse_and_hoist_forall.mlir"
    "pack_to_intrinsics.mlir"
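
The two options compose as exercised by the lit RUN lines in the new test below. For example, to concretize only the accumulator (with `input.mlir` as a placeholder input file):

    iree-opt input.mlir \
      --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-inputs=false}))'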
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
new file mode 100644
index 000000000000..990bfea08d6d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
@@ -0,0 +1,110 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-result=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-INPUTS
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-inputs=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-RESULT
+
+#contraction_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_16x16x16_F16(%lhs: tensor<2x2x16x16xf16>, %rhs: tensor<2x2x16x16xf16>, %acc: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, lowering_config = #config
+  } : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x16x16xf32>
+  return %0 : tensor<2x2x16x16xf32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_F32_16x16x16_F16
+// CHECK-SAME:    %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME:    %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME:    %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xf32>
+
+// CHECK-INPUTS-DAG:  %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 4] : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>
+// CHECK-INPUTS-DAG:  %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf16>
+// CHECK-INPUTS:      %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:   : tensor<2x2x16x4x4xf16>, tensor<2x2x4x4x16xf16> into tensor<2x2x16x16xf32>
+// CHECK-INPUTS:      return %[[MMA]]
+
+// CHECK-RESULT-DAG:  %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf32> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT:      %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:   : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT:      %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xf32> into tensor<2x2x16x16xf32>
+// CHECK-RESULT:      return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (j, k)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_I32_16x16x32_I8(%lhs: tensor<2x2x16x32xi8>, %rhs: tensor<2x2x16x32xi8>, %acc: tensor<2x2x16x16xi32>) -> tensor<2x2x16x16xi32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+    rhs_permutation = array<i64: 1, 0>, lowering_config = #config
+  } : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x16x16xi32>
+  return %0 : tensor<2x2x16x16xi32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_I32_16x16x32_I8
+// CHECK-SAME:    %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME:    %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME:    %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xi32>
+
+// CHECK-INPUTS-DAG:  %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS-DAG:  %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS:      %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:   rhs_permutation = array<i64: 1, 2, 0>
+// CHECK-INPUTS-SAME:   : tensor<2x2x16x4x8xi8>, tensor<2x2x16x4x8xi8> into tensor<2x2x16x16xi32>
+// CHECK-INPUTS:      return %[[MMA]]
+
+// CHECK-RESULT-DAG:  %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xi32> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT:      %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:   : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT:      %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xi32> into tensor<2x2x16x16xi32>
+// CHECK-RESULT:      return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_32x32x8_F16(%lhs: tensor<2x2x32x8xf16>, %rhs: tensor<2x2x8x32xf16>, %acc: tensor<2x2x32x32xf32>) -> tensor<2x2x32x32xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, lowering_config = #config
+  } : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x32x32xf32>
+  return %0 : tensor<2x2x32x32xf32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_F32_32x32x8_F16
+// CHECK-SAME:    %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x32x8xf16>
+// CHECK-SAME:    %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x8x32xf16>
+// CHECK-SAME:    %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x32x32xf32>
+
+// CHECK-INPUTS-DAG:  %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 32, 2, 4] : tensor<2x2x32x8xf16> into tensor<2x2x32x2x4xf16>
+// CHECK-INPUTS-DAG:  %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 2, 4, 32] : tensor<2x2x8x32xf16> into tensor<2x2x2x4x32xf16>
+// CHECK-INPUTS:      %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:   : tensor<2x2x32x2x4xf16>, tensor<2x2x2x4x32xf16> into tensor<2x2x32x32xf32>
+// CHECK-INPUTS:      return %[[MMA]]
+
+// CHECK-RESULT-DAG:  %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3, 4], [5]] output_shape [2, 2, 4, 2, 4, 32] : tensor<2x2x32x32xf32> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT:      %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:   lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:   : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT:      %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3, 4], [5]] : tensor<2x2x4x2x4x32xf32> into tensor<2x2x32x32xf32>
+// CHECK-RESULT:      return %[[COLLAPSED]]
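
Reading the first CHECK-RESULT case end-to-end: on the accumulator path the pass brackets the multi_mma with an expand/collapse pair, so consumers still see the original accumulator type. Roughly, with attributes elided:

    %expanded_acc = tensor.expand_shape %acc [[0], [1], [2, 3], [4]]
        output_shape [2, 2, 4, 4, 16]
        : tensor<2x2x16x16xf32> into tensor<2x2x4x4x16xf32>
    %mma = iree_gpu.multi_mma %lhs, %rhs, %expanded_acc {...}
        : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf32>
    %collapsed = tensor.collapse_shape %mma [[0], [1], [2, 3], [4]]
        : tensor<2x2x4x4x16xf32> into tensor<2x2x16x16xf32>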