diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 895512b39..aeee71259 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -336,7 +336,8 @@ def GpuConversion : Pass<"gpu-conversion", "ModuleOp"> { let options = [ Option<"useWmma", "wmma", "bool", /*default=*/"false", - "Use WMMA operations"> + "Use WMMA operations">, + ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, ]; } @@ -347,7 +348,7 @@ def GpuToCuda : Pass<"gpu-to-cuda", "ModuleOp"> { /*default=*/"\"nvptx64-nvidia-cuda\"", "GPU target triple.">, Option<"gpuChip", "chip", "std::string", - /*default=*/"\"sm_35\"", + /*default=*/"\"sm_70\"", "GPU target architecture.">, Option<"gpuFeatures", "features", "std::string", /*default=*/"\"+ptx60\"", @@ -458,7 +459,11 @@ def LinalgToGpu : Pass<"linalg-to-gpu", "func::FuncOp"> { let options = [ Option<"useWmma", "wmma", "bool", /*default=*/"false", - "Use WMMA operations"> + "Use WMMA operations">, + ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, + Option<"kTile", "k-tile", "int64_t", + /*default=*/"32", + "GEMM tile size for reduction dimension.">, ]; } diff --git a/lib/TPP/GPU/GpuConversion.cpp b/lib/TPP/GPU/GpuConversion.cpp index d39020974..5820c0b61 100644 --- a/lib/TPP/GPU/GpuConversion.cpp +++ b/lib/TPP/GPU/GpuConversion.cpp @@ -69,7 +69,7 @@ struct GpuConversion : public tpp::impl::GpuConversionBase, // the default lowering for any remaining ops. pm.addNestedPass(createLinalgDeGeneralize()); pm.addNestedPass( - createLinalgToGpu(LinalgToGpuOptions{useWmma})); + createLinalgToGpu(LinalgToGpuOptions{useWmma, warpTile})); pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); // Map loops into GPU kernels. diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp index 5f2de7979..1664343a2 100644 --- a/lib/TPP/GPU/GpuPipeline.cpp +++ b/lib/TPP/GPU/GpuPipeline.cpp @@ -46,6 +46,11 @@ llvm::cl::opt gpuWmma("gpu-wmma", llvm::cl::desc("Enable GPU WMMA support"), llvm::cl::init(false)); +llvm::cl::list wmmaTileSizes( + "wmma-tile-sizes", llvm::cl::desc("GPU WMMA tile sizes MxNxK"), + llvm::cl::list_init(SmallVector{16, 16, 16}), + llvm::cl::CommaSeparated); + namespace mlir { namespace tpp { #define GEN_PASS_DEF_GPUPIPELINE @@ -70,6 +75,31 @@ GpuType parseGpuOption(StringRef gpuStr) { return *type; } +struct GpuOptions { + std::string triple; + std::string chip; + std::string features; +}; + +GpuOptions getGpuOptions(GpuType gpuType) { + GpuOptions options; + + switch (gpuType) { + case GpuType::Cuda: { + options.triple = "nvptx64-nvidia-cuda"; + options.chip = "sm_70"; + options.features = "+ptx60"; + break; + } + case GpuType::Vulkan: { + // No options needed at the moment. + break; + } + } + + return options; +} + // GPU pipeline - map and lower operations to enable execution on a GPU. struct GpuPipeline : public tpp::impl::GpuPipelineBase, UtilityPassBase { @@ -112,6 +142,7 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, pm.clear(); GpuType gpuType = parseGpuOption(this->gpuBackend); + GpuOptions gpuOptions = getGpuOptions(gpuType); // Tile to split the kernel into threads and blocks. // Use default tiling to handle both packed and unpacked ops. @@ -128,21 +159,18 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, pm.addNestedPass(createCleanup()); // Convert to generic GPU ops. - pm.addPass(createGpuConversion(GpuConversionOptions{gpuWmma})); + pm.addPass( + createGpuConversion(GpuConversionOptions{gpuWmma, wmmaTileSizes})); // Lower GPU ops to the chosen GPU backend. 
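// Illustrative note (not part of the patch): the warp tile reaches the backend
// lowering below via the plumbing above, e.g. an invocation along the lines of
//   tpp-run kernel.mlir -gpu=cuda -gpu-wmma --wmma-tile-sizes=16,16,16
// forwards the M,N,K sizes through GpuConversionOptions into LinalgToGpu,
// while the CUDA defaults (sm_70, +ptx60) now come from getGpuOptions.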
switch (gpuType) { case GpuType::Cuda: { - std::string gpuTriple = "nvptx64-nvidia-cuda"; - std::string gpuChip = "sm_70"; - std::string gpuFeatures = "+ptx60"; - // Perform explicit GPU data transfers only for CUDA as the unified // memory is not currently used here. // Vulkan runner assumes usage of GPU unified memory. pm.addNestedPass(createGpuDataTransfer()); - pm.addPass( - createGpuToCuda(GpuToCudaOptions{gpuTriple, gpuChip, gpuFeatures})); + pm.addPass(createGpuToCuda(GpuToCudaOptions{ + gpuOptions.triple, gpuOptions.chip, gpuOptions.features})); break; } case GpuType::Vulkan: { diff --git a/lib/TPP/GPU/LinalgToGpu.cpp b/lib/TPP/GPU/LinalgToGpu.cpp index a49743c4d..26e260353 100644 --- a/lib/TPP/GPU/LinalgToGpu.cpp +++ b/lib/TPP/GPU/LinalgToGpu.cpp @@ -25,6 +25,8 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include + using namespace mlir; using namespace mlir::tpp; @@ -70,12 +72,16 @@ createGpuBlocksWrapper(Operation *op, ArrayRef blockDims, } // Return true if the operation can be represented with WMMA compute. -static bool supportsMMACompute(linalg::LinalgOp linalgOp) { - if (!(isa_and_nonnull(linalgOp) || - isa_and_nonnull(linalgOp))) { +static bool isMMACompatible(linalg::LinalgOp linalgOp, + ArrayRef warpTile, int kTile) { + if (!(isa(linalgOp) || + isa(linalgOp))) { return false; } + if (warpTile.size() != 3) + return false; + // Only static shapes are supported. if (linalgOp.hasDynamicShape()) return false; @@ -84,122 +90,171 @@ static bool supportsMMACompute(linalg::LinalgOp linalgOp) { auto bType = linalgOp.getDpsInputs()[1].getType().cast(); auto cType = linalgOp.getDpsInits()[0].getType().cast(); - ArrayRef shapeA = aType.getShape(); - ArrayRef shapeC = cType.getShape(); - int64_t m = shapeC[0]; - int64_t n = shapeC[1]; - // Buffer A might be 2D (gemm) or 3D (brgemm) but the last dimension will - // always be reduction. - int64_t k = shapeA.back(); - - // For now, only M-N-K F16[16] x F16[16] x F16[16] WMMA variant is supported. - // TODO: add more WMMA combinations. - return aType.getElementType().isF16() && bType.getElementType().isF16() && - cType.getElementType().isF16() && m == 16 && n == 16 && k == 16; + auto elemTypeA = aType.getElementType(); + auto elemTypeB = bType.getElementType(); + auto elemTypeC = cType.getElementType(); + + // TODO: Add more WMMA combinations. + bool isSupportedPrecision = + (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF16()) || + (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF32()); + if (!isSupportedPrecision) + return false; + + auto mDim = cType.getShape()[0]; + auto nDim = cType.getShape()[1]; + auto kDim = aType.getShape().back(); + + // Validate warp tile sizes. + // The computation dimensions must fit into the tiles. + // Reduction dimension tile size has to be compatible + // with the warp tile. + int wmmaTileM = warpTile[0]; + int wmmaTileN = warpTile[1]; + int wmmaTileK = warpTile[2]; + if ((mDim % wmmaTileM != 0) || (nDim % wmmaTileN != 0) || + (kDim % wmmaTileK != 0) || (kTile % wmmaTileK != 0)) { + return false; + } + + return true; } // Fuse a consumer using WMMA operations. // Returns updated store op or nullopt if the fusion fails. 
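// Illustrative example: for a 32x32 output lowered with a 16x16x16 warp tile,
// gemmToGpuMMA below produces four per-sub-tile stores, so rootStoreOps holds
// four ops and the fused add/relu is replicated once per sub-tile store.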
-static std::optional -mmaFusion(linalg::LinalgOp rootOp, linalg::LinalgOp consumer, - gpu::SubgroupMmaStoreMatrixOp rootStoreOp, ValueRange storeIndices, - PatternRewriter &rewriter) { +static std::optional> +eltwiseFusion(linalg::LinalgOp rootOp, linalg::LinalgOp consumer, + SmallVector rootStoreOps, + PatternRewriter &rewriter) { + assert(rootStoreOps.size() > 0 && "Requires at least one store op"); + Location loc = rootOp.getLoc(); auto rootOutput = rootOp.getDpsInits()[0]; auto outputType = rootOutput.getType().cast(); + // Must be a floating point type. + // TODO: Add integer support. auto floatType = dyn_cast(outputType.getElementType()); if (!floatType) return std::nullopt; - gpu::MMAMatrixType mmaOutputType = gpu::MMAMatrixType::get( - outputType.getShape(), outputType.getElementType(), "COp"); - auto leadingDim = rootStoreOp.getLeadDimension(); - - Value zero = rewriter.create(loc, 0); - // Insert fused eltwise ops before the store and later replace the store // with a new result. OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(rootStoreOp); + rewriter.setInsertionPoint(rootStoreOps[0]); + + // It is assumed that WMMA tile sizes do not vary between different + // operations i.e., the original workload has been split into + // a series of operations using the same WMMA configuration. + gpu::MMAMatrixType mmaOutputType = rootStoreOps[0].getSrc().getType(); + auto leadingDim = rootStoreOps[0].getLeadDimension(); + + // Collect new results after fusion. + SmallVector fusedRes; - std::optional newStore = std::nullopt; SmallVector operands; if (structured_match::utils::isTwoDAddOp(consumer, &operands)) { // Get the value to be added - load the tile first. // Must be a buffer of the same type - scalar broadcast is not supported. + // TODO: Add support for eltwise with broadcast. auto addValue = (operands[0] != rootOutput) ? operands[0] : operands[1]; if (addValue.getType() != rootOutput.getType()) return std::nullopt; - // Fuse the add into the matmul body. - addValue = rewriter - .create( - loc, mmaOutputType, addValue, ValueRange{zero, zero}, - leadingDim, - /*transpose=*/UnitAttr()) - .getRes(); - auto eltwiseAttr = gpu::MMAElementwiseOp::ADDF; - auto addRes = - rewriter - .create( - loc, mmaOutputType, ValueRange{rootStoreOp.getSrc(), addValue}, - eltwiseAttr) - .getRes(); - // Store the new result. - newStore = rewriter.replaceOpWithNewOp( - rootStoreOp, addRes, rootStoreOp.getDstMemref(), ValueRange{zero, zero}, - leadingDim, - /*transpose=*/UnitAttr()); + + for (gpu::SubgroupMmaStoreMatrixOp rootStoreOp : rootStoreOps) { + auto storeIndices = rootStoreOp.getIndices(); + + // Fuse the add into the matmul body. + auto loadOp = + rewriter + .create( + loc, mmaOutputType, addValue, storeIndices, leadingDim, + /*transpose=*/UnitAttr()) + .getRes(); + auto eltwiseAttr = gpu::MMAElementwiseOp::ADDF; + auto addRes = + rewriter + .create( + loc, mmaOutputType, ValueRange{rootStoreOp.getSrc(), loadOp}, + eltwiseAttr) + .getRes(); + fusedRes.push_back(addRes); + } } else if (structured_match::utils::isTwoDReluOp(consumer, &operands)) { - // Fuse the relu into the matmul body. Value zeroFloat = rewriter.create( loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); + Value zeroTile = rewriter.create( loc, mmaOutputType, zeroFloat); - auto eltwiseAttr = gpu::MMAElementwiseOp::MAXF; - auto maxRes = - rewriter - .create( - loc, mmaOutputType, ValueRange{rootStoreOp.getSrc(), zeroTile}, - eltwiseAttr) - .getRes(); - // Store the new result. 
- newStore = rewriter.replaceOpWithNewOp( - rootStoreOp, maxRes, rootStoreOp.getDstMemref(), ValueRange{zero, zero}, - leadingDim, - /*transpose=*/UnitAttr()); + for (auto rootStoreOp : rootStoreOps) { + // Fuse the relu into the matmul body. + auto eltwiseAttr = gpu::MMAElementwiseOp::MAXF; + auto maxRes = + rewriter + .create( + loc, mmaOutputType, + ValueRange{rootStoreOp.getSrc(), zeroTile}, eltwiseAttr) + .getRes(); + fusedRes.push_back(maxRes); + } } else { // Not a fusable operation. Bail out. return std::nullopt; } + // Fusion must have failed, if number of new results is different. + // Bail out. + if (fusedRes.size() != rootStoreOps.size()) + return std::nullopt; + + // Store the new result. + SmallVector newStores; + for (size_t i = 0; i < rootStoreOps.size(); i++) { + auto storeIndices = rootStoreOps[i].getIndices(); + + auto newStore = rewriter.create( + loc, fusedRes[i], rootStoreOps[i].getDstMemref(), storeIndices, + leadingDim, + /*transpose=*/UnitAttr()); + newStores.push_back(newStore); + } + + // Replace store ops and cleanup standalone consumer. + for (size_t i = 0; i < rootStoreOps.size(); i++) + rewriter.replaceOp(rootStoreOps[i], newStores[i]); + rewriter.eraseOp(consumer); - return newStore; + return newStores; } // Fuse a consumer using scalar operations. +// TODO: Extend scalar fusion to support multiple stores. +// // Returns updated store op or nullopt if the fusion fails. -static std::optional scalarFusion(linalg::LinalgOp rootOp, - linalg::LinalgOp consumer, - memref::StoreOp rootStoreOp, - ValueRange storeIndices, - PatternRewriter &rewriter) { +static std::optional eltwiseFusion(linalg::LinalgOp rootOp, + linalg::LinalgOp consumer, + memref::StoreOp rootStoreOp, + PatternRewriter &rewriter) { Location loc = rootOp.getLoc(); auto rootOutput = rootOp.getDpsInits()[0]; auto outputType = rootOutput.getType().cast(); + // Must be a floating point type. + // TODO: Add integer support. auto floatType = dyn_cast(outputType.getElementType()); if (!floatType) return std::nullopt; + auto storeIndices = rootStoreOp.getIndices(); + // Insert fused eltwise ops before the store and later replace the store // with a new result. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(rootStoreOp); - std::optional newStore = std::nullopt; + std::optional newStore = std::nullopt; SmallVector operands; if (structured_match::utils::isTwoDAddOp(consumer, &operands)) { // Get the value to be added. Load the element first, if necessary. @@ -233,15 +288,12 @@ static std::optional scalarFusion(linalg::LinalgOp rootOp, return newStore; } -// Fuse elementwise consumers. -// A naive fusion strategy that looks at the other operations after the root +// Find operations fusable with the given root op. +// +// A simple fusion strategy that looks at the other operations after the root // linalg op and tries to fuse them. -// Attemps bails on the first mismatch. -// Returns updated store op. -static Operation *fuseEltwiseConsumers(linalg::LinalgOp rootOp, - Operation *rootStoreOp, - ValueRange storeIndices, - PatternRewriter &rewriter) { +static SmallVector +getFusableConsumers(linalg::LinalgOp rootOp) { auto *parentOp = rootOp->getParentOp(); auto rootOutput = rootOp.getDpsInits()[0]; @@ -265,37 +317,52 @@ static Operation *fuseEltwiseConsumers(linalg::LinalgOp rootOp, auto outBuf = consumer.getDpsInitOperand(0)->get(); // Check that the op reuses the same output buffer as the root op. // Otherwise, it is assumed that the op cannot be fused. 
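    // For example (illustrative): a linalg.generic that adds a bias into the
    // matmul's output buffer, or applies ReLU on it in place, matches the
    // isTwoDAddOp / isTwoDReluOp patterns handled by eltwiseFusion and is
    // collected here; an op writing to a different buffer ends the scan.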
+ // TODO: Consider adding support for eltwise with broadcast. if (outBuf != rootOutput) break; consumers.push_back(consumer); } + return consumers; +} + +// Fuse elementwise consumers within a GPU kernel. +// +// Fusion bails on the first mismatch. +// Returns updated store ops. +template +static StoreTy fuseEltwiseConsumers(linalg::LinalgOp rootOp, + StoreTy rootStoreOps, + PatternRewriter &rewriter) { + // Constrain conversion to the supported fusion types. + static_assert( + llvm::is_one_of>::value); + + auto consumers = getFusableConsumers(rootOp); + for (auto op : consumers) { - std::optional updatedStoreOp = std::nullopt; - if (auto storeOp = dyn_cast(rootStoreOp)) { - updatedStoreOp = - scalarFusion(rootOp, op, storeOp, storeIndices, rewriter); - } else if (auto mmaStore = - dyn_cast(rootStoreOp)) { - updatedStoreOp = mmaFusion(rootOp, op, mmaStore, storeIndices, rewriter); - } + std::optional updatedStoreOps = std::nullopt; - // Not a fusable operation. Bail out. - if (!updatedStoreOp) + updatedStoreOps = eltwiseFusion(rootOp, op, rootStoreOps, rewriter); + + // Failed to fuse operation. Bail out. + if (!updatedStoreOps) break; - rootStoreOp = *updatedStoreOp; + rootStoreOps = *updatedStoreOps; } - return rootStoreOp; + return rootStoreOps; } // Create WMMA instructions out of matmul-like operation. static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, + ArrayRef warpTile, int kTile, PatternRewriter &rewriter) { - assert((isa_and_nonnull(linalgOp) || - isa_and_nonnull(linalgOp)) && + assert((isa(linalgOp) || + isa(linalgOp)) && "Requires a matmul like op for MMA lowering"); Location loc = linalgOp.getLoc(); @@ -316,13 +383,6 @@ static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, auto typeB = matB.getType().cast(); auto typeC = matC.getType().cast(); - gpu::MMAMatrixType mmaTypeA = gpu::MMAMatrixType::get( - typeA.getShape().take_back(2), typeA.getElementType(), "AOp"); - gpu::MMAMatrixType mmaTypeB = gpu::MMAMatrixType::get( - typeB.getShape().take_back(2), typeB.getElementType(), "BOp"); - gpu::MMAMatrixType mmaTypeC = - gpu::MMAMatrixType::get(typeC.getShape(), typeC.getElementType(), "COp"); - auto stridesA = utils::getStaticStrides(matA); auto stridesB = utils::getStaticStrides(matB); auto stridesC = utils::getStaticStrides(matC); @@ -337,6 +397,21 @@ static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, "Expect unit stride in the innermost dimension for MMA operations"); } + int dimM = typeC.getShape()[0]; + int dimN = typeC.getShape()[1]; + int dimK = typeA.getShape().back(); + + int64_t wmmaTileM = warpTile[0]; + int64_t wmmaTileN = warpTile[1]; + int64_t wmmaTileK = warpTile[2]; + + gpu::MMAMatrixType mmaTypeA = gpu::MMAMatrixType::get( + {wmmaTileM, wmmaTileK}, typeA.getElementType(), "AOp"); + gpu::MMAMatrixType mmaTypeB = gpu::MMAMatrixType::get( + {wmmaTileK, wmmaTileN}, typeB.getElementType(), "BOp"); + gpu::MMAMatrixType mmaTypeC = gpu::MMAMatrixType::get( + {wmmaTileM, wmmaTileN}, typeC.getElementType(), "COp"); + bool isBrgemm = isa(linalgOp); // Skip batch dimension stride in case of brgemm. @@ -359,60 +434,165 @@ static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, rewriter.setInsertionPoint(parallelLoop.getBody()->getTerminator()); // Fetch the inital value of the output element. 
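  // Illustrative example: with a 32x32 C tile and warp-tile=16,16,16, the loop
  // introduced below loads four C sub-tiles at offsets (0,0), (0,16), (16,0)
  // and (16,16); each sub-tile then becomes an iter_arg of the reduction loops.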
- Value tileC = rewriter - .create( - loc, mmaTypeC, matC, ValueRange{zero, zero}, ldc, - /*transpose=*/UnitAttr()) - .getRes(); + SmallVector tilesC; + for (int m = 0; m < dimM; m += wmmaTileM) { + for (int n = 0; n < dimN; n += wmmaTileN) { + Value rowIdx = rewriter.create(loc, m); + Value colIdx = rewriter.create(loc, n); + Value tileC = + rewriter + .create( + loc, mmaTypeC, matC, ValueRange{rowIdx, colIdx}, ldc, + /*transpose=*/UnitAttr()) + .getRes(); + tilesC.push_back(tileC); + } + } + + // Create a loop and step into it. + auto startLoop = [&](int lb, int ub, int step) -> scf::ForOp { + Value lbCst = rewriter.create(loc, lb); + Value ubCst = rewriter.create(loc, ub); + Value stepCst = rewriter.create(loc, step); + scf::ForOp loopOp = + rewriter.create(loc, lbCst, ubCst, stepCst, tilesC); + rewriter.setInsertionPointToStart(loopOp.getBody()); + return loopOp; + }; + auto getLoopIterValues = [&](scf::ForOp loopOp) -> SmallVector { + SmallVector loopIterVals; + for (auto iterArg : loopOp.getRegionIterArgs()) + loopIterVals.push_back(iterArg); + return loopIterVals; + }; + // Construct and move into batch reduction loop. + // Propagate output values as iter args. scf::ForOp batchLoop; Value batchIv; if (isBrgemm) { - Value batch = - rewriter.create(loc, typeA.getShape()[0]); - batchLoop = - rewriter.create(loc, zero, batch, one, ValueRange{tileC}); - rewriter.setInsertionPointToStart(batchLoop.getBody()); + batchLoop = startLoop(0, typeA.getShape()[0], 1); batchIv = batchLoop.getInductionVar(); - tileC = batchLoop.getRegionIterArg(0); + tilesC = getLoopIterValues(batchLoop); } - Value tileA = rewriter - .create( - loc, mmaTypeA, matA, - isBrgemm ? ValueRange{batchIv, zero, zero} - : ValueRange{zero, zero}, - lda, - /*transpose=*/UnitAttr()) - .getRes(); - Value tileB = rewriter - .create( - loc, mmaTypeB, matB, - isBrgemm ? ValueRange{batchIv, zero, zero} - : ValueRange{zero, zero}, - ldb, /*transpose=*/UnitAttr()) - .getRes(); - - Value result = - rewriter - .create(loc, tileC.getType(), tileA, tileB, - tileC, /*a_transpose=*/UnitAttr(), - /*b_transpose=*/UnitAttr()) - .getRes(); + // Construct and move into GEMM reduction dimension tiling loop. + // Propagate output values as iter args. + scf::ForOp kDimLoop = startLoop(0, dimK, kTile); + Value kDimIv = kDimLoop.getInductionVar(); + tilesC = getLoopIterValues(kDimLoop); + + // Load A sub-tiles. + SmallVector tilesA; + for (int m = 0; m < dimM; m += wmmaTileM) { + for (int k = 0; k < kTile; k += wmmaTileK) { + Value rowOffset = rewriter.create(loc, m); + Value colOffset = rewriter.create(loc, k); + + Value rowIdx = rowOffset; + Value colIdx = rewriter.create(loc, kDimIv, colOffset); + + Value tileA = rewriter + .create( + loc, mmaTypeA, matA, + isBrgemm ? ValueRange{batchIv, rowIdx, colIdx} + : ValueRange{rowIdx, colIdx}, + lda, + /*transpose=*/UnitAttr()) + .getRes(); + tilesA.push_back(tileA); + } + } + + // Load B sub-tiles. + SmallVector tilesB; + for (int k = 0; k < kTile; k += wmmaTileK) { + for (int n = 0; n < dimN; n += wmmaTileN) { + Value rowOffset = rewriter.create(loc, k); + Value colOffset = rewriter.create(loc, n); + + Value rowIdx = rewriter.create(loc, kDimIv, rowOffset); + Value colIdx = colOffset; + + Value tileB = rewriter + .create( + loc, mmaTypeB, matB, + isBrgemm ? 
ValueRange{batchIv, rowIdx, colIdx} + : ValueRange{rowIdx, colIdx}, + ldb, /*transpose=*/UnitAttr()) + .getRes(); + tilesB.push_back(tileB); + } + } + const int numTilesM = dimM / wmmaTileM; + const int numTilesN = dimN / wmmaTileN; + const int numTilesK = kTile / wmmaTileK; + + // Compute sub-tiles of the C tile. + // + // Iterate over the reduction dimension sub-tiles as the outermost + // loop to minimize read after write conflicts between partial + // computations of the same C sub-tile. + // + // Initialize sub-tiles with the loaded C tiles. + SmallVector results = tilesC; + for (int k = 0; k < numTilesK; k++) { + for (int m = 0; m < numTilesM; m++) { + for (int n = 0; n < numTilesN; n++) { + int aIdx = m * numTilesK + k; + int bIdx = k * numTilesN + n; + int cIdx = m * numTilesN + n; + + Value result = rewriter + .create( + loc, tilesC[cIdx].getType(), tilesA[aIdx], + tilesB[bIdx], results[cIdx], + /*a_transpose=*/UnitAttr(), + /*b_transpose=*/UnitAttr()) + .getRes(); + // Update sub-tile partial result. + results[cIdx] = result; + } + } + } + + // Create loop terminator and exit the loop. + auto terminateLoop = [&](scf::ForOp loopOp, SmallVector resultValues) { + rewriter.setInsertionPointToEnd(loopOp.getBody()); + rewriter.create(loc, resultValues); + rewriter.setInsertionPointAfter(loopOp); + }; + + // Terminate and exit reduction dim loop. + terminateLoop(kDimLoop, results); + results = kDimLoop.getResults(); + + // Terminate and exit batch reduce loop. if (isBrgemm) { - rewriter.setInsertionPointToEnd(batchLoop.getBody()); - rewriter.create(loc, ValueRange{result}); - result = batchLoop.getResults()[0]; - rewriter.setInsertionPointAfter(batchLoop); + terminateLoop(batchLoop, results); + results = batchLoop.getResults(); } - // Write back the total sum to the output buffer. - auto storeOp = rewriter.create( - loc, result, matC, ValueRange{zero, zero}, ldc, /*transpose=*/UnitAttr()); + // Write back the final C sub-tiles results to the output buffer. + SmallVector storeOps; + for (int m = 0; m < numTilesM; m++) { + for (int n = 0; n < numTilesN; n++) { + int resIdx = m * numTilesN + n; + + Value rowIdx = + rewriter.create(loc, m * wmmaTileM); + Value colIdx = + rewriter.create(loc, n * wmmaTileN); + auto storeOp = rewriter.create( + loc, results[resIdx], matC, ValueRange{rowIdx, colIdx}, ldc, + /*transpose=*/UnitAttr()); + storeOps.push_back(storeOp); + } + } - (void)fuseEltwiseConsumers(linalgOp, storeOp, ValueRange{zero, zero}, - rewriter); + (void)fuseEltwiseConsumers>( + linalgOp, storeOps, rewriter); rewriter.eraseOp(linalgOp); @@ -422,8 +602,8 @@ static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, // Create loops out of matmul-like operation. static LogicalResult gemmToGpuLoops(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) { - assert((isa_and_nonnull(linalgOp) || - isa_and_nonnull(linalgOp)) && + assert((isa(linalgOp) || + isa(linalgOp)) && "Requires a matmul like op for loop lowering"); Location loc = linalgOp.getLoc(); @@ -516,71 +696,58 @@ static LogicalResult gemmToGpuLoops(linalg::LinalgOp linalgOp, auto storeOp = rewriter.create(loc, result, matC, parallelIvs); - (void)fuseEltwiseConsumers(linalgOp, storeOp, parallelIvs, rewriter); + (void)fuseEltwiseConsumers(linalgOp, storeOp, rewriter); rewriter.eraseOp(linalgOp); return success(); } -// Convert linalg.matmul to GPU-compatible kernel. 
-struct ConvertGemmToGpu : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Convert linalg.matmul or linalg.batch_reduce_matmul to GPU-compatible kernel. +template +struct ConvertGemmLikeToGpu : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + // Constrain conversion to the supported GEMM-like ops. + static_assert(llvm::is_one_of::value); - ConvertGemmToGpu(MLIRContext *ctx, bool useWmma) - : OpRewritePattern(ctx), useWmma(useWmma) {} + ConvertGemmLikeToGpu(MLIRContext *ctx, LinalgToGpuOptions options) + : OpRewritePattern(ctx), options(options) {} - LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + LogicalResult matchAndRewrite(LinalgOpTy gemmLikeOp, PatternRewriter &rewriter) const override { - if (!matmulOp.hasPureBufferSemantics()) { + if (!gemmLikeOp.hasPureBufferSemantics()) { return rewriter.notifyMatchFailure( - matmulOp, "Linalg gemm to GPU expects memref type"); + gemmLikeOp, "Linalg brgemm to GPU expects memref type"); } - if (matmulOp.hasDynamicShape()) { + if (gemmLikeOp.hasDynamicShape()) { return rewriter.notifyMatchFailure( - matmulOp, "Expect static shape when mapping to GPU"); + gemmLikeOp, "Expect static shape when mapping to GPU"); } - if (useWmma && supportsMMACompute(matmulOp)) - return gemmToGpuMMA(matmulOp, rewriter); - return gemmToGpuLoops(matmulOp, rewriter); - } - -private: - bool useWmma; -}; - -// Convert linalg.batch_reduce_matmul to GPU-compatible kernel. -struct ConvertBrgemmToGpu - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - ConvertBrgemmToGpu(MLIRContext *ctx, bool useWmma) - : OpRewritePattern(ctx), useWmma(useWmma) {} + // Ensure that reduction dimension tiling also works for smaller workloads. + auto aType = + gemmLikeOp.getDpsInputs()[0].getType().template cast(); + auto kDim = aType.getShape().back(); + auto kTile = kDim < options.kTile ? kDim : options.kTile; - LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp, - PatternRewriter &rewriter) const override { - if (!brgemmOp.hasPureBufferSemantics()) { - return rewriter.notifyMatchFailure( - brgemmOp, "Linalg brgemm to GPU expects memref type"); + if (options.useWmma && + isMMACompatible(gemmLikeOp, options.warpTile, kTile)) { + return gemmToGpuMMA(gemmLikeOp, options.warpTile, kTile, rewriter); } - if (brgemmOp.hasDynamicShape()) { - return rewriter.notifyMatchFailure( - brgemmOp, "Expect static shape when mapping to GPU"); - } - - if (useWmma && supportsMMACompute(brgemmOp)) - return gemmToGpuMMA(brgemmOp, rewriter); - return gemmToGpuLoops(brgemmOp, rewriter); + // TODO: Add warp and K dim tiling to looped implementation. 
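    // Illustrative note: ops that fail isMMACompatible, e.g. f32 inputs or
    // shapes not divisible by the warp tile, fall back to the plain loop
    // lowering below.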
+ return gemmToGpuLoops(gemmLikeOp, rewriter); } private: - bool useWmma; + LinalgToGpuOptions options; }; -void populateLinalgToGpuPatterns(RewritePatternSet &patterns, bool useWmma) { - patterns.add(patterns.getContext(), - useWmma); +void populateLinalgToGpuPatterns(RewritePatternSet &patterns, + LinalgToGpuOptions options) { + patterns.add, + ConvertGemmLikeToGpu>( + patterns.getContext(), options); } struct LinalgToGpu : public tpp::impl::LinalgToGpuBase { @@ -588,7 +755,8 @@ struct LinalgToGpu : public tpp::impl::LinalgToGpuBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateLinalgToGpuPatterns(patterns, useWmma); + populateLinalgToGpuPatterns(patterns, + LinalgToGpuOptions{useWmma, warpTile, kTile}); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; diff --git a/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir new file mode 100644 index 000000000..143ce9424 --- /dev/null +++ b/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir @@ -0,0 +1,14 @@ +// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ +// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: -entry-point-result=void -e entry 2>&1 | \ +// RUN: FileCheck %s + +func.func @entry(%arg0: memref<16x32x32xf16>, + %arg1: memref<16x32x32xf16>, + %arg2: memref<32x32xf16>) -> memref<32x32xf16> { + linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<16x32x32xf16>, memref<16x32x32xf16>) + outs(%arg2 : memref<32x32xf16>) + return %arg2 : memref<32x32xf16> +} + +// CHECK-COUNT-32: ( 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513 ) diff --git a/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir new file mode 100644 index 000000000..0e850cc4e --- /dev/null +++ b/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir @@ -0,0 +1,11 @@ +// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ +// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: -entry-point-result=void -e entry 2>&1 | \ +// RUN: FileCheck %s + +func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) -> memref<32x32xf16> { + linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<32x32xf16>) + return %arg2 : memref<32x32xf16> +} + +// CHECK-COUNT-32: ( 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33 ) diff --git a/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir new file mode 100644 index 000000000..a71f4f889 --- /dev/null +++ b/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir @@ -0,0 +1,28 @@ +// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ +// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: -entry-point-result=void -e entry 2>&1 | \ +// RUN: FileCheck %s + +// XFAIL:* +// See: #870 + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>, %arg3: memref<32x32xf16>) -> memref<32x32xf16> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f16 + linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) 
outs(%arg3 : memref<32x32xf16>) + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) { + ^bb0(%in: f16, %out: f16): + %0 = arith.addf %in, %out : f16 + linalg.yield %0 : f16 + } + linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<32x32xf16>) { + ^bb0(%out: f16): + %0 = arith.maximumf %out, %cst : f16 + linalg.yield %0 : f16 + } + return %arg3 : memref<32x32xf16> +} + +// CHECK-COUNT-32: ( 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34 ) diff --git a/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir b/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir new file mode 100644 index 000000000..249e86f38 --- /dev/null +++ b/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir @@ -0,0 +1,29 @@ +// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ +// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: -entry-point-result=void -e entry 2>&1 | \ +// RUN: FileCheck %s + +// XFAIL:* +// See: #870 + +#map = affine_map<(d0, d1) -> (d0, d1)> + +func.func @entry(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf16>) -> memref<16x16xf16> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f16 + linalg.matmul ins(%arg0, %arg1 : memref<16x16xf16>, memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) { + ^bb0(%in: f16, %out: f16): + %0 = arith.addf %in, %out : f16 + linalg.yield %0 : f16 + } + linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<16x16xf16>) { + ^bb0(%out: f16): + %0 = arith.maximumf %out, %cst : f16 + linalg.yield %0 : f16 + } + return %arg3 : memref<16x16xf16> +} + +// CHECK-COUNT-16: ( 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18 ) diff --git a/test/GPU/linalg-to-gpu-wmma.mlir b/test/GPU/linalg-to-gpu-wmma.mlir index 9f685e104..75a440634 100644 --- a/test/GPU/linalg-to-gpu-wmma.mlir +++ b/test/GPU/linalg-to-gpu-wmma.mlir @@ -1,4 +1,4 @@ -// RUN: tpp-opt %s -linalg-to-gpu=wmma -split-input-file | FileCheck %s +// RUN: tpp-opt %s -linalg-to-gpu="wmma=1 warp-tile=16,16,16" -canonicalize -split-input-file | FileCheck %s func.func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, @@ -11,20 +11,99 @@ func.func @matmul(%arg0: memref<16x16xf16>, // CHECK-LABEL: func.func @matmul( // CHECK-SAME: %[[A:.+]]: memref<16x16xf16>, %[[B:.+]]: memref<16x16xf16>, %[[C:.+]]: memref<16x16xf16> // CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[one:.+]] = arith.constant 1 : index -// CHECK: scf.parallel {{.*}}to (%[[one]], %[[one]]) -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce +// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) +// 
CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 +// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 +// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 +// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] +// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 // CHECK: scf.reduce // CHECK: } // ----- +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @matmul_wide_tiled(%arg0: memref<16x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<16x32xf16>) { + linalg.matmul ins(%arg0, %arg1 : memref<16x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<16x32xf16>) + return +} + +// Assumes 16x16 WMMA tiles. +// +// CHECK-LABEL: func.func @matmul_wide_tiled( +// CHECK-SAME: %[[A:.+]]: memref<16x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<16x32xf16> +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 +// CHECK-DAG: %[[c16:.+]] = arith.constant 16 +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[C]] +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c0]], %[[c0]] +// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c0]], %[[c16]] +// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c16]], %[[c0]] +// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c16]], %[[c16]] +// CHECK-COUNT-4: gpu.subgroup_mma_compute +// CHECK-COUNT-2: gpu.subgroup_mma_store_matrix{{.*}}, %[[C]] + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @matmul_tall_tiled(%arg0: memref<32x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<32x16xf16>) { + linalg.matmul ins(%arg0, %arg1 : memref<32x16xf16>, memref<16x16xf16>) outs(%arg2 : memref<32x16xf16>) + return +} + +// CHECK-LABEL: func.func @matmul_tall_tiled( +// CHECK-SAME: %[[A:.+]]: memref<32x16xf16>, %[[B:.+]]: memref<16x16xf16>, %[[C:.+]]: memref<32x16xf16> +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[C]] +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-COUNT-1: gpu.subgroup_mma_load_matrix %[[B]] +// CHECK-COUNT-2: gpu.subgroup_mma_compute +// CHECK-COUNT-2: gpu.subgroup_mma_store_matrix + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @matmul_2D_tiled(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) { + linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<32x32xf16>) + return +} + +// CHECK-LABEL: func.func @matmul_2D_tiled( +// CHECK-SAME: %[[A:.+]]: memref<32x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<32x32xf16> +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] +// CHECK-COUNT-8: gpu.subgroup_mma_compute +// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix + +// ----- + +func.func @matmul_K_dim_tiled(%arg0: memref<16x64xf16>, %arg1: memref<64x16xf16>, %arg2: memref<16x16xf16>) { + linalg.matmul ins(%arg0, %arg1 : memref<16x64xf16>, memref<64x16xf16>) outs(%arg2 : memref<16x16xf16>) + return +} + +// CHECK-LABEL: func.func @matmul_K_dim_tiled( +// CHECK-SAME: %[[A:.+]]: memref<16x64xf16>, %[[B:.+]]: memref<64x16xf16>, %[[C:.+]]: memref<16x16xf16> +// CHECK-DAG: %[[zero:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[kStep:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[kUB:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[wmmaSizeK:.+]] = arith.constant 16 : index +// 
CHECK-COUNT-1: %[[cTile:.+]] = gpu.subgroup_mma_load_matrix %[[C]] +// CHECK: %[[loopRes:.+]] = scf.for %[[iv:.+]] = %[[zero]] to %[[kUB]] step %[[kStep]] iter_args(%[[acc_tile:.+]] = %[[cTile]]) +// CHECK: gpu.subgroup_mma_load_matrix %[[A]]{{\[}}%[[zero]], %[[iv]] +// CHECK: %[[aCol:.+]] = arith.addi %[[iv]], %[[wmmaSizeK]] +// CHECK: gpu.subgroup_mma_load_matrix %[[A]]{{\[}}%[[zero]], %[[aCol]] +// CHECK: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[iv]], %[[zero]] +// CHECK: %[[bRow:.+]] = arith.addi %[[iv]], %[[wmmaSizeK]] +// CHECK: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[bRow]], %[[zero]] +// CHECK: gpu.subgroup_mma_compute +// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute +// CHECK: scf.yield %[[res]] +// CHECK: } +// CHECK: gpu.subgroup_mma_store_matrix %[[loopRes]], %[[C]] + +// ----- + func.func @batch_reduce_matmul(%arg0: memref<64x16x16xf16>, %arg1: memref<64x16x16xf16>, %arg2: memref<16x16xf16>) { @@ -33,28 +112,76 @@ func.func @batch_reduce_matmul(%arg0: memref<64x16x16xf16>, return } - // CHECK-LABEL: func.func @batch_reduce_matmul( // CHECK-SAME: %[[A:.+]]: memref<64x16x16xf16>, %[[B:.+]]: memref<64x16x16xf16>, %[[C:.+]]: memref<16x16xf16> // CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index // CHECK-DAG: %[[batch:.+]] = arith.constant 64 : index // CHECK-DAG: %[[one:.+]] = arith.constant 1 : index -// CHECK: scf.parallel {{.*}}to (%[[one]], %[[one]]) -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args(%[[acc_tile:.*]] = %[[tileC]]) -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[part_sum:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[acc_tile]] -// CHECK: scf.yield %[[part_sum]] -// CHECK: } -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce +// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) +// CHECK: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 +// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args(%[[acc_tile:.*]] = %[[tileC]]) +// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 +// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 +// CHECK: %[[part_sum:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[acc_tile]] +// CHECK: scf.yield %[[part_sum]] +// CHECK: } +// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 // CHECK: scf.reduce // CHECK: } // ----- +func.func @batch_reduce_matmul_2D_tiled(%arg0: memref<64x32x32xf16>, + %arg1: memref<64x32x32xf16>, + %arg2: memref<32x32xf16>) { + linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<64x32x32xf16>, memref<64x32x32xf16>) + outs(%arg2 : memref<32x32xf16>) + return +} + +// CHECK-LABEL: func.func @batch_reduce_matmul_2D_tiled( +// CHECK-SAME: %[[A:.+]]: memref<64x32x32xf16>, %[[B:.+]]: memref<64x32x32xf16>, %[[C:.+]]: memref<32x32xf16> +// CHECK-DAG: %[[batch:.+]] = arith.constant 64 : index +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] +// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] +// CHECK-COUNT-8: gpu.subgroup_mma_compute +// CHECK: 
scf.yield{{.*}}: !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: } +// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix + +// ----- + +func.func @batch_reduce_matmul_K_dim_tiled(%arg0: memref<32x16x64xf16>, + %arg1: memref<32x64x16xf16>, + %arg2: memref<16x16xf16>) { + linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<32x16x64xf16>, memref<32x64x16xf16>) + outs(%arg2 : memref<16x16xf16>) + return +} + +// CHECK-LABEL: func.func @batch_reduce_matmul_K_dim_tiled( +// CHECK-SAME: %[[A:.+]]: memref<32x16x64xf16>, %[[B:.+]]: memref<32x64x16xf16>, %[[C:.+]]: memref<16x16xf16> +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c16:.+]] = arith.constant 16 : index +// CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[c64:.+]] = arith.constant 64 : index +// CHECK-COUNT-1: %[[cTile:.+]] = gpu.subgroup_mma_load_matrix %[[C]] +// CHECK: %[[batchLoopRes:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c32]] step %[[c1]] iter_args(%[[acc_batch:.+]] = %[[cTile]]) +// CHECK: %[[kLoopRes:.+]] = scf.for %[[iv:.+]] = %[[zero]] to %[[kUB]] step %[[kStep]] iter_args(%[[acc_k_dim:.+]] = %[[acc_batch]]) +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[B]] +// CHECK: gpu.subgroup_mma_compute +// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute +// CHECK: scf.yield %[[res]] +// CHECK: scf.yield %[[kLoopRes]] +// CHECK: } +// CHECK: gpu.subgroup_mma_store_matrix %[[batchLoopRes]], %[[C]] + +// ----- + func.func @matmul_strided_memrefs(%arg0: memref<16x32x16xf16>, %arg1: memref<16x64x16xf16>, %arg2: memref<32x32xf16>) { %subview = memref.subview %arg0[0, 0, 0] [16, 1, 16] [1, 1, 1] : memref<16x32x16xf16> to memref<16x16xf16, strided<[512, 1], offset: 0>> @@ -77,14 +204,12 @@ func.func @matmul_strided_memrefs(%arg0: memref<16x32x16xf16>, %arg1: memref<16x // CHECK-DAG: %[[subA:.+]] = memref.subview %[[A]] // CHECK-DAG: %[[subB:.+]] = memref.subview %[[B]] // CHECK-DAG: %[[subC:.+]] = memref.subview %[[C]] -// CHECK: scf.parallel {{.*}}to (%[[one]], %[[one]]) -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[subC]]{{.*}}leadDimension = 32 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[subA]]{{.*}}leadDimension = 512 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[subB]]{{.*}}leadDimension = 1024 -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[subC]]{{.*}}leadDimension = 32 -// CHECK: scf.reduce +// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) +// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[subC]]{{.*}}leadDimension = 32 +// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[subA]]{{.*}}leadDimension = 512 +// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[subB]]{{.*}}leadDimension = 1024 +// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] +// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[subC]]{{.*}}leadDimension = 32 // CHECK: scf.reduce // CHECK: } @@ -102,18 +227,6 @@ func.func @wrong_data_type(%arg0: memref<16x16xf32>, // CHECK-LABEL: func.func @wrong_data_type( // CHECK-NOT: gpu.{{.*}}_mma_ -// Operands' shapes do not match supported WMMA shapes. 
-func.func @wrong_shapes(%arg0: memref<32x32xf16>, - %arg1: memref<32x32xf16>, - %arg2: memref<32x32xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) - outs(%arg2 : memref<32x32xf16>) - return -} - -// CHECK-LABEL: func.func @wrong_shapes( -// CHECK-NOT: gpu.{{.*}}_mma_ - // ----- // Dynamic shapes are not supported. @@ -167,17 +280,57 @@ func.func @matmul_add_relu(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, % // CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index // CHECK-DAG: %[[one:.+]] = arith.constant 1 : index // CHECK-DAG: %[[zeroF16:.+]] = arith.constant 0.000000e+00 : f16 -// CHECK: scf.parallel {{.*}}to (%[[one]], %[[one]]) -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[compRes:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: %[[tileBias:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{.*}}leadDimension = 16 -// CHECK: %[[addRes:.+]] = gpu.subgroup_mma_elementwise addf %[[compRes]], %[[tileBias]] -// CHECK: %[[tileCstZero:.+]] = gpu.subgroup_mma_constant_matrix %[[zeroF16]] -// CHECK: %[[reluRes:.+]] = gpu.subgroup_mma_elementwise maxf %[[addRes]], %[[tileCstZero]] -// CHECK: gpu.subgroup_mma_store_matrix %[[reluRes]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce +// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) +// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 +// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 +// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 +// CHECK: %[[compRes:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] +// CHECK: %[[tileBias:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{.*}}leadDimension = 16 +// CHECK: %[[addRes:.+]] = gpu.subgroup_mma_elementwise addf %[[compRes]], %[[tileBias]] +// CHECK: %[[tileCstZero:.+]] = gpu.subgroup_mma_constant_matrix %[[zeroF16]] +// CHECK: %[[reluRes:.+]] = gpu.subgroup_mma_elementwise maxf %[[addRes]], %[[tileCstZero]] +// CHECK: gpu.subgroup_mma_store_matrix %[[reluRes]], %[[C]]{{.*}}leadDimension = 16 // CHECK: scf.reduce // CHECK: } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @matmul_add_relu_2D_tiled(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>, %arg3: memref<32x32xf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f16 + linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) { + ^bb0(%in: f16, %out: f16): + %0 = arith.addf %in, %out : f16 + linalg.yield %0 : f16 + } + linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<32x32xf16>) { + ^bb0(%out: f16): + %0 = arith.maximumf %out, %cst : f16 + linalg.yield %0 : f16 + } + return +} + +// CHECK-LABEL: func.func @matmul_add_relu_2D_tiled( +// CHECK-SAME: %[[A:.+]]: memref<32x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[BIAS:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<32x32xf16> +// CHECK-DAG: %[[f0:.+]] 
= arith.constant 0.0 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 +// CHECK-DAG: %[[c16:.+]] = arith.constant 16 +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] +// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] +// CHECK-COUNT-8: gpu.subgroup_mma_compute +// CHECK-DAG: %[[b0:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c0]], %[[c0]] +// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b0]] +// CHECK-DAG: %[[b1:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c0]], %[[c16]] +// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b1]] +// CHECK-DAG: %[[b2:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c16]], %[[c0]] +// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b2]] +// CHECK-DAG: %[[b3:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c16]], %[[c16]] +// CHECK: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b3]] +// CHECK: %[[cstMat:.+]] = gpu.subgroup_mma_constant_matrix %[[f0]] +// CHECK-COUNT-4: gpu.subgroup_mma_elementwise maxf{{.*}}, %[[cstMat]] +// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix{{.*}}, %[[C]]
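// Illustrative usage sketch (not part of the change above); option spellings
// are taken from Passes.td:
//   tpp-opt kernel.mlir -linalg-to-gpu="wmma=1 warp-tile=16,16,16 k-tile=16"
// The precision check in isMMACompatible also admits f16 x f16 -> f32
// accumulation, so a kernel such as
//   linalg.matmul ins(%a, %b : memref<32x32xf16>, memref<32x32xf16>)
//                 outs(%c : memref<32x32xf32>)
// is expected to take the WMMA path under a 16,16,16 warp tile.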