diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ffe82da50cb8..554ecbfcc5de 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -290,8 +290,9 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
   auto aKLayout = inner;
   auto bKLayout = inner;
   auto bNLayout = outer;
-  auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {8, 2});
-  auto cNLayout = outer;
+  auto cMLayout =
+      PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
+  auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
   return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
                            bNLayout, cMLayout, cNLayout};
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 17492ea21cea..94e8a16c2229 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1647,6 +1647,7 @@ transform_dialect::SetContractionLayoutAttributes::apply(
            << "invalid opaque mma layout for annotation " << mmaType;
   }
 
+  contract->setAttr("iree.amdgpu.mma", mmaType);
   auto [aLayout, bLayout, cLayout] = *maybeLayouts;
   contract->setAttr("__vector_layout_test_anchor_operand_0", aLayout);
   contract->setAttr("__vector_layout_test_anchor_operand_1", bLayout);
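The attribute written by `SetContractionLayoutAttributes` above is the hand-off point to the distribution pattern changed below. A minimal sketch of the round trip, assembled from code in this patch (the C++ attribute class name `IREE::GPU::MMAAttr` is assumed to be what backs `#iree_gpu.mma_layout`):

```cpp
// Producer (transform op): record which MMA intrinsic the anchored layouts
// were derived from, so downstream patterns need not re-infer it.
contract->setAttr("iree.amdgpu.mma", mmaType);

// Consumer (distribution pattern): read the intrinsic back, and emit a
// match-failure diagnostic when the annotation is absent.
auto mmaAttr =
    contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr)
  return rewriter.notifyMatchFailure(
      contractOp, "missing iree.amdgpu.mma intrinsic attribute");
```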
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
index 39cdd7cc473c..6253dd783c15 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -7,6 +7,7 @@
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 
@@ -18,83 +19,6 @@ using VectorValue = TypedValue<VectorType>;
 enum class ContractMatrixType { A, B, C, D };
 enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
 
-/// We define AMD MFMA instruction layouts only for contract type MM, i.e. C(i,
-/// k) += A(i, j) * B(j, k). We call these canonical layouts: layoutA, layoutB,
-/// layoutC, corresponding to the A, B, C matrices.
-///
-/// For any other contract type, the layout is simply transposed for that
-/// operand. For example, for MMT, the layouts used should be layoutA,
-/// layoutB.T, layoutC. For an easier understanding of this transposition,
-/// think of the transpose simply being outside the contract:
-///
-/// vector.contract {type = MMT} %a, %b
-///
-/// is equivalent to
-///
-/// %bt = vector.transpose %b
-/// vector.contract {type = MM} %a, %bt
-///
-/// Now, you would assign layouts based on contract type MM, and would get the
-/// right layout for %b by transposing the layout for B.
-///
-/// Now that we have defined what layoutA, layoutB, layoutC are, we will define
-/// what the canonical layouts are for each MFMA instruction. These are
-/// represented as the original matrix, with elements representing which thread
-/// id in the subgroup gets which element.
-/// These layouts were referenced from
-/// https://github.com/ROCm/amd_matrix_instruction_calculator
-///
-/// The naming scheme for these operators is InputType_MxNxK_OutputType.
-enum class MFMAType {
-  /// layoutA:
-  /// 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
-  /// 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
-  /// 2 2 2 2 18 18 18 18 34 34 34 34 50 50 50 50
-  /// ...
-  /// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
-  ///
-  /// layoutB:
-  /// Transpose of layoutA
-  ///
-  /// layoutC:
-  /// Same as layoutB
-  F16_16x16x16_F32,
-  /// layoutA:
-  /// 0 0 0 0 32 32 32 32
-  /// 1 1 1 1 33 33 33 33
-  /// 2 2 2 2 34 34 34 34
-  /// ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
-  /// 31 31 31 31 63 63 63 63
-  ///
-  /// layoutB:
-  /// Transpose of layoutA
-  ///
-  /// layoutC:
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 0 1 2 ... 31
-  /// ⋮ ⋮ ⋮ ... ⋮
-  /// 32 33 34 ... 63
-  /// 0 1 2 ... 31
-  /// ⋮ ⋮ ⋮ ... ⋮
-  /// 32 33 34 ... 63
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 0 1 2 ... 31
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  F16_32x32x8_F32,
-};
-
 namespace {
 
 static bool isOperandATransposed(ContractType contractType) {
@@ -166,208 +90,6 @@ struct DistributeContractions final
     return ContractType::UNSUPPORTED;
   }
 
-  Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
-                   MFMAType mfmaType) const {
-    uint32_t m, n, k, blks;
-    if (mfmaType == MFMAType::F16_16x16x16_F32) {
-      m = n = k = 16;
-    } else if (mfmaType == MFMAType::F16_32x32x8_F32) {
-      m = n = 32;
-      k = 8;
-    }
-    blks = 1;
-    return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
-                                           b, c);
-  }
-
-  PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
-                                      ArrayRef<LayoutDimension> dims,
-                                      ArrayRef<int64_t> shapes) const {
-    SmallVector<LayoutDimensionAttr> dimAttrs;
-    for (auto dim : dims)
-      dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
-    return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
-  }
-
-  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
-      LayoutDimension batchRowLabel, int64_t batchRow,
-      LayoutDimension batchColLabel, int64_t batchCol) const {
-    MLIRContext *ctx = getContext();
-    PerDimLayoutAttr rowLayout = createPerDimLayout(
-        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
-    PerDimLayoutAttr colLayout = createPerDimLayout(
-        ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
-        {batchCol, 4, 4});
-    return {rowLayout, colLayout};
-  }
-
-  bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
-        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    return isCompatible16x16x16B(layout, batchRow, batchCol);
-  }
-
-  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
-  createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
-                                LayoutDimension batchColLabel, int64_t batchCol,
-                                ContractMatrixType matrixType) const {
-    MLIRContext *ctx = getContext();
-    PerDimLayoutAttr rowLayout = createPerDimLayout(
-        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
-    PerDimLayoutAttr colLayout;
-    if (matrixType == ContractMatrixType::C) {
-      colLayout =
-          createPerDimLayout(ctx,
-                             {batchColLabel, LayoutDimension::VECTORY,
-                              LayoutDimension::LANEY, LayoutDimension::VECTORX},
-                             {batchCol, 4, 2, 4});
-    } else {
-      colLayout = createPerDimLayout(
-          ctx,
-          {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
-          {batchCol, 2, 4});
-    }
-    return {rowLayout, colLayout};
-  }
-
-  bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
-        ContractMatrixType::A);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
-        ContractMatrixType::B);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
-        ContractMatrixType::C);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
-                            int64_t batchRow, int64_t batchCol) const {
-    switch (matrixType) {
-    case ContractMatrixType::A:
-      return isCompatible16x16x16A(layout, batchRow, batchCol);
-    case ContractMatrixType::B:
-      return isCompatible16x16x16B(layout, batchRow, batchCol);
-    default:
-      return isCompatible16x16x16C(layout, batchRow, batchCol);
-    }
-    return false;
-  }
-
-  bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
-                           int64_t batchRow, int64_t batchCol) const {
-    switch (matrixType) {
-    case ContractMatrixType::A:
-      return isCompatible32x32x8A(layout, batchRow, batchCol);
-    case ContractMatrixType::B:
-      return isCompatible32x32x8B(layout, batchRow, batchCol);
-    default:
-      return isCompatible32x32x8C(layout, batchRow, batchCol);
-    }
-    return false;
-  }
-
-  bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
-                    MFMAType mfmaType) const {
-    std::optional<int64_t> batchRow = layout.getBatchDim(0);
-    if (!batchRow)
-      return false;
-    std::optional<int64_t> batchCol = layout.getBatchDim(1);
-    if (!batchCol)
-      return false;
-    switch (mfmaType) {
-    case MFMAType::F16_16x16x16_F32:
-      return isCompatible16x16x16(layout, matrixType, batchRow.value(),
-                                  batchCol.value());
-    case MFMAType::F16_32x32x8_F32:
-      return isCompatible32x32x8(layout, matrixType, batchRow.value(),
-                                 batchCol.value());
-    default:
-      return false;
-    }
-    return false;
-  }
-
-  // If we have a prior guess of the MFMA type, only evaluate that type.
-  // Otherwise, evaluate all types to find a match.
-  std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
-                                        ContractMatrixType matrixType,
-                                        std::optional<MFMAType> prior) const {
-    SmallVector<MFMAType> mfmaTypes;
-    if (prior) {
-      mfmaTypes.push_back(prior.value());
-    } else {
-      mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
-    }
-    for (MFMAType mfmaType : mfmaTypes) {
-      if (isCompatible(layout, matrixType, mfmaType))
-        return mfmaType;
-    }
-    return std::nullopt;
-  }
-
-  // Inputs are LHS, RHS and ACC operands and corresponding layouts.
-  // Output is inferred MFMAType or none (if layout is not compatible with any
-  // MFMA layout).
-  std::optional<MFMAType>
-  inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
-                          ContractType contractType) const {
-    std::optional<MFMAType> mfmaType{std::nullopt};
-    SmallVector<ContractMatrixType> matrixTypes{
-        ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
-
-    // Canonical layouts for MFMA are transposes of each other.
-    if (isOperandATransposed(contractType)) {
-      matrixTypes[0] = ContractMatrixType::B;
-    }
-    if (isOperandBTransposed(contractType)) {
-      matrixTypes[1] = ContractMatrixType::A;
-    }
-
-    for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
-      mfmaType = inferMFMAType(layout, matrixType, mfmaType);
-      if (!mfmaType)
-        return std::nullopt;
-    }
-    return mfmaType;
-  }
-
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
@@ -408,10 +130,12 @@ struct DistributeContractions final
     if (contractType == ContractType::UNSUPPORTED)
       return failure();
 
-    std::optional<MFMAType> mfmaType =
-        inferCompatibleMFMAType(layouts, contractType);
-    if (!mfmaType)
-      return failure();
+    auto mmaAttr =
+        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+    if (!mmaAttr) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "missing iree.amdgpu.mma intrinsic attribute");
+    }
 
     std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
     if (!rowBatch)
@@ -434,10 +158,13 @@ struct DistributeContractions final
       Value bMatrix = rewriter.create<vector::ExtractOp>(
          loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
          getIndices(contractType, ContractMatrixType::B, k, indices[1]));
-      dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
-                           mfmaType.value());
+      dMatrix = mmaAttr
+                    .buildMmaOperation(rewriter, loc, dMatrix.getType(),
+                                       aMatrix, bMatrix, dMatrix)
+                    .value();
     }
     vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
+    return success();
   };
 
   LayoutIterator iterator(resultLayout);
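`buildMmaOperation` takes over the role of the deleted `computeMMA` helper: the single-intrinsic op construction now lives behind the attribute interface. A simplified sketch of the dispatch such a method plausibly performs, with the enum and case names assumed from the `#iree_gpu.mma_layout<...>` values used in the tests below (the real implementation sits alongside the attribute in IREEGPUAttrs.cpp and may differ in detail):

```cpp
// Hypothetical sketch: pick the hardware op from the intrinsic enum rather
// than from a per-pattern MFMAType inferred out of the anchored layouts.
FailureOr<Value> buildMmaOperationSketch(OpBuilder &b, Location loc,
                                         MMAIntrinsic intrinsic,
                                         Type resultType, Value lhs, Value rhs,
                                         Value acc) {
  switch (intrinsic) {
  case MMAIntrinsic::MFMA_F16_16x16x16_F32:
    // Single-block MFMA with m = n = k = 16, the shape the old computeMMA
    // hardcoded for F16_16x16x16_F32.
    return b.create<amdgpu::MFMAOp>(loc, resultType, /*m=*/16, /*n=*/16,
                                    /*k=*/16, /*blocks=*/1, lhs, rhs, acc)
        .getResult();
  case MMAIntrinsic::MFMA_F16_32x32x8_F32:
    return b.create<amdgpu::MFMAOp>(loc, resultType, /*m=*/32, /*n=*/32,
                                    /*k=*/8, /*blocks=*/1, lhs, rhs, acc)
        .getResult();
  case MMAIntrinsic::WMMA_F16_16x16x16_F32:
    // RDNA3 WMMA carries its shape in the operand/result vector types, so no
    // explicit m/n/k is passed.
    return b.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
        .getResult();
  }
  return failure();
}
```

This is also why the pattern above can return a `rewriter.notifyMatchFailure` instead of silently failing: once the intrinsic is an explicit annotation, a missing or mismatched attribute is a diagnosable condition rather than an inference dead end.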
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index c5f7fd225bdd..695ecb0eb6e8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -29,6 +29,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Codegen/Common",
         "//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
         "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index b94f2d6a0f36..49a9e7eb7e91 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -41,6 +41,7 @@ iree_cc_library(
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Common::VectorLayoutAnalysis
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
    iree::compiler::Codegen::Transforms
    iree::compiler::Codegen::Utils
    iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index ec04bc8d5389..1757b5ce48f4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -20,6 +20,7 @@ iree_lit_test_suite(
         [
             "amdgpu_chained_matmul.mlir",
             "amdgpu_contraction_distribution.mlir",
+            "amdgpu_set_anchor_layouts.mlir",
            "attention.mlir",
            "attention_mfma.mlir",
            "conv_pipeline_test_cuda.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index be60c9716a3b..2ff84aa75ea2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
   SRCS
     "amdgpu_chained_matmul.mlir"
     "amdgpu_contraction_distribution.mlir"
+    "amdgpu_set_anchor_layouts.mlir"
    "attention.mlir"
    "attention_mfma.mlir"
    "cast_address_space_function.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
index b46a8d05f2ab..fd1241393bb5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -4,6 +4,8 @@
 // layoutC means and how these layouts are assigned based on the instruction
 // type.
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -22,15 +24,6 @@
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
-    // CHECK-LABEL: distribute_mfma_16x16x16_mmt
-    // CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
-    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
-    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>,
                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -42,14 +35,31 @@ builtin.module attributes { transform.with_named_sequence } {
     return %output : vector<16x16xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
+
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -70,8 +80,6 @@ builtin.module attributes { transform.with_named_sequence } {
 #layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
-    // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
-    // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>,
                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -83,14 +91,24 @@ builtin.module attributes { transform.with_named_sequence } {
     return %output : vector<32x64xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
+
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -111,15 +129,6 @@ builtin.module attributes { transform.with_named_sequence } {
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
-    // CHECK-LABEL: distribute_mfma_32x32x8_mm
-    // CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
-    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
-    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
-    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>,
                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -131,14 +140,31 @@ builtin.module attributes { transform.with_named_sequence } {
     return %output : vector<32x32xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
+
+// CHECK-LABEL: distribute_mfma_32x32x8_mm
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -160,7 +186,6 @@ builtin.module attributes { transform.with_named_sequence } {
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_32x32x8_mtm(%a : vector<8x64xf16>, %b : vector<8x32xf16>, %c : vector<64x32xf32>) -> vector<64x32xf32> {
-    // CHECK-LABEL: distribute_mfma_32x32x8_mtm
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>,
                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -169,17 +194,126 @@ builtin.module attributes { transform.with_named_sequence } {
                } %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
-    // CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
-    // CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
-    // CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
-    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
-    // CHECK-NOT: amdgpu.mfma
     return %output : vector<64x32xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: distribute_mfma_32x32x8_mtm
+
+// CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
+// CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
+// CHECK-NOT: amdgpu.mfma
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<16x16>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<16x16>, layout = transpose(layoutB) = layoutA
+// Since shapes are also same, we can use the same layout attribute, layout_a.
+
+// C: vector<16x16>, layout = layoutC
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_wmma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+               kind = #vector.kind<add>,
+               "__vector_layout_test_anchor_operand_0" = #layout_a,
+               "__vector_layout_test_anchor_operand_1" = #layout_a,
+               "__vector_layout_test_anchor_operand_2" = #layout_c,
+               "__vector_layout_test_anchor_result_0" = #layout_c
+               } %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x8xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<8xf32> from vector<1x1x8xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.wmma %[[AV]] * %[[BV]] + %[[CV]] : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<32x128>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<64x128>, layout = transpose(layoutB) = layoutA
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [4, 16]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+
+// C: vector<32x64>, layout = layoutC
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [2, 8, 2, 1]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_wmma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+               kind = #vector.kind<add>,
+               "__vector_layout_test_anchor_operand_0" = #layout_a,
+               "__vector_layout_test_anchor_operand_1" = #layout_b,
+               "__vector_layout_test_anchor_operand_2" = #layout_c,
+               "__vector_layout_test_anchor_result_0" = #layout_c
+               } %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
new file mode 100644
index 000000000000..2269dc68d15a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s --verify-diagnostics
+
+// This tests that the compiler is setting the correct layout anchors for various vector ops and shapes.
+// Currently this only tests contraction layoutV1, but it can be expanded to other ops and layouts.
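The expected layouts in the remarks below can be sanity-checked against the thread maps that the deleted documentation block described. As a standalone illustration (hypothetical helper, not part of the patch): for the MFMA_F16_16x16x16_F32 A matrix, LANEX covers the 16 rows and LANEY x VECTORX = 4 x 4 covers the 16 columns, so lane ownership reduces to a one-line formula:

```cpp
// Which of the 64 lanes owns A(i, j) for MFMA_F16_16x16x16_F32, matching the
// "0 0 0 0 16 16 16 16 ..." diagram removed above: lanes advance by 16 every
// 4 columns (LANEY), and rows map directly onto LANEX.
int laneForA16x16x16(int i, int j) { return i + 16 * (j / 4); }
```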
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_mfma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+    // CHECK-LABEL: anchor_mfma_16x16x16_mmt
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 4, 4]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 4, 4]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[ BATCHY, LANEX], [1, 16]>>}}
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_mfma_16x16x16_mmt_batch(%a : memref<32x128xf16>, %b : memref<64x128xf16>, %init : vector<32x64xf32>) -> vector<32x64xf32> {
+    // CHECK-LABEL: anchor_mfma_16x16x16_mmt_batch
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x128xf16>, vector<32x128xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [2, 16]>, <[ BATCHY, LANEY, VECTORX], [8, 4, 4]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [4, 16]>, <[ BATCHY, LANEY, VECTORX], [8, 4, 4]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [2, 4, 4]>, <[ BATCHY, LANEX], [4, 16]>>}}
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_wmma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+    // CHECK-LABEL: anchor_wmma_16x16x16_mmt
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>, <[ BATCHY, LANEX], [1, 16]>>}}
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 8d1716261d6d..dd52c2ce41fa 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -158,6 +158,8 @@ SmallVector<int64_t> LayoutIterator::State::computeSIMTIndex() const {
     if (isVectorDimension(name)) {
       int64_t step{1};
       if (name == LayoutDimension::VECTORY) {
+        assert(ranges.contains(LayoutDimension::VECTORX) &&
+               "Expected VectorX to be specified on layouts with VectorY.");
         step = ranges.lookup(LayoutDimension::VECTORX).stop;
       }
       vecOffset = vecOffset.value_or(0) + it.getPosition() * step;
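The assert added here guards an implicit invariant of the new WMMA accumulator layouts: VECTORY is an outer vector dimension whose stride is the size of the inner VECTORX dimension, so a layout using VECTORY without VECTORX has no well-defined SIMT offset. A minimal self-contained sketch of the offset computation being protected (simplified from `computeSIMTIndex`; not the verbatim implementation):

```cpp
#include <cassert>
#include <cstdint>

// Flat offset of (vecY, vecX) within one thread's SIMT vector fragment:
// VECTORY steps over whole rows of VECTORX elements. With the WMMA C layout
// above (VECTORY = 8, VECTORX = 1), element (y, 0) lands at offset y.
int64_t simtVectorOffset(int64_t vecY, int64_t vecX, int64_t vectorXSize) {
  assert(vectorXSize > 0 && "VECTORX must be specified alongside VECTORY");
  return vecY * vectorXSize + vecX;
}
```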