diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt
index dc0d45c25..40d75418f 100644
--- a/build_tools/llvm_version.txt
+++ b/build_tools/llvm_version.txt
@@ -1 +1 @@
-bf684034844c660b778f0eba103582f582b710c9
+3654f1baa66f524c89e40ab24e18e594e56363e9
diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp
index ebab1af8d..3c2a9fb19 100644
--- a/lib/TPP/DefaultPipeline.cpp
+++ b/lib/TPP/DefaultPipeline.cpp
@@ -153,15 +153,17 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase,
     pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
   } else {
     // Apply the default preprocessing pass
-    DefaultTppPassesOptions tppDefaultOptions;
+    DefaultTppPassesOptions tppDefaultOptions;
     tppDefaultOptions.linalgToLoops = linalgToLoops;
-    tppDefaultOptions.parallelTaskGrid = parallelTaskGrid;
+    tppDefaultOptions.parallelTaskGrid = SmallVector{
+        parallelTaskGrid.begin(), parallelTaskGrid.end()};
     tppDefaultOptions.linalgToVector = linalgToVector;
     tppDefaultOptions.vectorToXSMM = vectorToXSMM;
-    tppDefaultOptions.lowerPackUnpackWithoutTranspose =
-        lowerPackUnpackWithoutTranspose;
-    tppDefaultOptions.lhsTile = lhsTile;
-    tppDefaultOptions.rhsTile = rhsTile;
+    tppDefaultOptions.lowerPackUnpackWithoutTranspose = lowerPackUnpackWithoutTranspose;
+    tppDefaultOptions.lhsTile =
+        SmallVector{lhsTile.begin(), lhsTile.end()};
+    tppDefaultOptions.rhsTile =
+        SmallVector{rhsTile.begin(), rhsTile.end()};
     tppDefaultOptions.vectorToKernel = vectorToKernel;
 
     pm.addPass(createDefaultTppPasses(tppDefaultOptions));
diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp
index b5cc4e635..afe8a3a51 100644
--- a/lib/TPP/DefaultTppPasses.cpp
+++ b/lib/TPP/DefaultTppPasses.cpp
@@ -23,6 +23,8 @@
 #include "TPP/PassUtils.h"
 #include "mlir/Transforms/Passes.h"
 
+#include
+
 using namespace mlir;
 using namespace mlir::tpp;
 
@@ -136,7 +138,8 @@ struct DefaultTppPasses
     if (linalgToVector || forceLinalgToVector) {
       // Vectorizes the remaining Linalg operations
       pm.addNestedPass(createBrgemmLinalgTiling(
-          BrgemmLinalgTilingOptions{lhsTile, rhsTile}));
+          BrgemmLinalgTilingOptions{SmallVector{*lhsTile},
+                                    SmallVector{*rhsTile}}));
 
       pm.addNestedPass(createLoopInvariantCodeMotionPass());
       pm.addNestedPass(createVectorizationPass());
@@ -159,7 +162,8 @@ struct DefaultTppPasses
     // Convert forAll to parallel loops should run after bufferization
     // as scf.parallel does not handle tensor.
     pm.addPass(createConvertForAllToParallelOp());
-    LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid};
+    LowLevelParallelizationOptions LowLevelParallelization{
+        SmallVector{*parallelTaskGrid}};
 
     if (linalgToVector) {
       pm.addPass(createConvertVectorToSCFPass());
diff --git a/lib/TPP/GPU/GpuConversion.cpp b/lib/TPP/GPU/GpuConversion.cpp
index c89d39e01..806eebeab 100644
--- a/lib/TPP/GPU/GpuConversion.cpp
+++ b/lib/TPP/GPU/GpuConversion.cpp
@@ -65,8 +65,8 @@ struct GpuConversion : public tpp::impl::GpuConversionBase,
     // the default lowering for any remaining ops.
     pm.addNestedPass(createLinalgDeGeneralize());
     if (isIntel) {
-      pm.addNestedPass(
-          createLinalgToXeGPU(LinalgToXeGPUOptions{kTile, stages, dpasTile}));
+      pm.addNestedPass(createLinalgToXeGPU(LinalgToXeGPUOptions{
+          kTile, stages, SmallVector{*dpasTile}}));
     }
     pm.addNestedPass(createConvertLinalgToLoopsPass());
     pm.addPass(createCleanup());
diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp
index 06f238aaf..4a0118fc3 100644
--- a/lib/TPP/GPU/GpuPipeline.cpp
+++ b/lib/TPP/GPU/GpuPipeline.cpp
@@ -172,7 +172,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase,
     // First split computation into grid with blocks of specified size.
     TileConsumerAndFuseProducersOptions blockTileOptions;
     if (!llvm::any_of(gpuBlockTile, [](int64_t tile) { return tile == -1; }))
-      blockTileOptions.tileSizes = gpuBlockTile;
+      blockTileOptions.tileSizes =
+          SmallVector{gpuBlockTile.begin(), gpuBlockTile.end()};
     blockTileOptions.minTileFactor = 1;
     pm.addPass(createTileConsumerAndFuseProducers(blockTileOptions));
 
@@ -182,7 +183,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase,
     // chance for outlining.
     TileConsumerAndFuseProducersOptions threadTileOptions;
     if (!llvm::any_of(gpuThreadTile, [](int64_t tile) { return tile == -1; }))
-      threadTileOptions.tileSizes = gpuThreadTile;
+      threadTileOptions.tileSizes =
+          SmallVector{gpuThreadTile.begin(), gpuThreadTile.end()};
     threadTileOptions.minTileFactor = 1;
     pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions));
     pm.addPass(createCleanup());
@@ -214,7 +216,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase,
 
     // Convert to generic GPU ops.
     pm.addPass(createGpuConversion(GpuConversionOptions{
-        gpuType == GpuType::Intel, kTile, stages, gpuDpasTile}));
+        gpuType == GpuType::Intel, kTile, stages,
+        SmallVector{gpuDpasTile.begin(), gpuDpasTile.end()}}));
 
     // Lower GPU ops to the chosen GPU backend.
     switch (gpuType) {
diff --git a/lib/TPP/GPU/LinalgToXeGPU.cpp b/lib/TPP/GPU/LinalgToXeGPU.cpp
index 88b88c734..243a9174d 100644
--- a/lib/TPP/GPU/LinalgToXeGPU.cpp
+++ b/lib/TPP/GPU/LinalgToXeGPU.cpp
@@ -512,8 +512,9 @@ createGemmCoopPrefetchTile(PatternRewriter &rewriter, linalg::LinalgOp linalgOp,
   auto srcType = cast(src.getType());
 
-  auto prefetchType =
-      xegpu::TensorDescType::get({numRows, numCols}, srcType.getElementType());
+  auto prefetchType = xegpu::TensorDescType::get(
+      {numRows, numCols}, srcType.getElementType(), /*array_length=*/1,
+      /*boundary_check=*/true);
 
   Value threadId = getGpuLinearThreadId(rewriter, loc);
 
@@ -620,7 +621,9 @@ static SmallVector createDescriptorTiles(PatternRewriter &rewriter,
   assert(arrayLength == 1 && "Array descriptors are not supported");
 
   auto type = cast(src.getType());
-  auto descType = xegpu::TensorDescType::get(descTile, type.getElementType());
+  auto descType = xegpu::TensorDescType::get(descTile, type.getElementType(),
+                                             /*array_length=*/1,
+                                             /*boundary_check=*/true);
 
   // Create the root descriptor.
   //
@@ -868,8 +871,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   int dimK = typeA.getShape().back();
 
   // Create C sub-tiles.
-  auto dpasTypeC = xegpu::TensorDescType::get({dpasTileM, dpasTileN},
-                                              typeC.getElementType());
+  auto dpasTypeC = xegpu::TensorDescType::get(
+      {dpasTileM, dpasTileN}, typeC.getElementType(), /*array_length=*/1,
+      /*boundary_check=*/true);
   SmallVector tilesC = createDescriptorTiles(
       rewriter, loc, matC, typeC.getShape(), {0, 0}, dpasTypeC.getShape());
 
@@ -1385,7 +1389,10 @@ struct LinalgToXeGPU : public tpp::impl::LinalgToXeGPUBase {
   using LinalgToXeGPUBase::LinalgToXeGPUBase;
 
   void runOnOperation() override {
-    LinalgToXeGPUOptions options{kTile, stages, dpasTile};
+    LinalgToXeGPUOptions options;
+    options.kTile = kTile;
+    options.stages = stages;
+    options.dpasTile = SmallVector{*dpasTile};
 
     // Run GEMM pattern first to allow fusion with its consumers.
     RewritePatternSet gemmPatterns(&getContext());
diff --git a/lib/TPP/PassBundles/LinalgLowering.cpp b/lib/TPP/PassBundles/LinalgLowering.cpp
index ae1d94c4b..938159a40 100644
--- a/lib/TPP/PassBundles/LinalgLowering.cpp
+++ b/lib/TPP/PassBundles/LinalgLowering.cpp
@@ -18,6 +18,8 @@
 #include "TPP/Dialect/Xsmm/XsmmDialect.h"
 #include "TPP/PassUtils.h"
 
+#include
+
 using namespace mlir;
 using namespace mlir::tpp;
 
@@ -48,7 +50,7 @@ struct LinalgLowering : public tpp::impl::LinalgLoweringBase,
 private:
   void constructPipeline() override {
     ConvertLinalgToXsmmOptions linalgOptions;
-    linalgOptions.skipOperations = skipOperations;
+    linalgOptions.skipOperations = SmallVector{*skipOperations};
     pm.addPass(createConvertLinalgToXsmm(linalgOptions));
     pm.addPass(createCombineXsmmOpPass());
     pm.addPass(createFoldXsmmFlags());
diff --git a/lib/TPP/PassBundles/LowLevelParallelization.cpp b/lib/TPP/PassBundles/LowLevelParallelization.cpp
index b8f5de694..ace149cfa 100644
--- a/lib/TPP/PassBundles/LowLevelParallelization.cpp
+++ b/lib/TPP/PassBundles/LowLevelParallelization.cpp
@@ -63,7 +63,7 @@ struct LowLevelParallelization
     pm.addPass(createCleanup());
 
     mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
-    tilingOptions.tileSizes = parallelTaskGrid;
+    tilingOptions.tileSizes = SmallVector{*parallelTaskGrid};
     pm.addPass(createSCFParallelLoopTiling(tilingOptions));
   }
 };
diff --git a/lib/TPP/Runner/MLIRBench.cpp b/lib/TPP/Runner/MLIRBench.cpp
index 343942514..dead56af8 100644
--- a/lib/TPP/Runner/MLIRBench.cpp
+++ b/lib/TPP/Runner/MLIRBench.cpp
@@ -351,8 +351,8 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
   outerDim = outputType.getShape()[0];
 
   // Vector undefined value
-  Value minusOne = builder.create(
-      unkLoc, getTypedAttr(builder, outElmType, -1.0));
+  Value undefLengthCst = builder.create(
+      unkLoc, getTypedAttr(builder, outElmType, 0.0));
 
   // Loop through the shaped type, transfer each dim to vector
   auto count = getConstIndex(builder, outerDim);
@@ -364,7 +364,7 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
   // Loop body
   auto beginIdx = loop.getInductionVar();
   auto vector = builder.create(
-      unkLoc, vecType, val, ValueRange{beginIdx, zero}, minusOne);
+      unkLoc, vecType, val, ValueRange{beginIdx, zero}, undefLengthCst);
   printVector(vector);
 
   // Finally lower to LLVM Dialect
diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
index 57300ed5d..2cc74504b 100644
--- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
+++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -223,7 +223,9 @@ struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase
   void runOnOperation() override {
-    BrgemmLinalgTilingOptions options{mTileShape, nTileShape};
+    BrgemmLinalgTilingOptions options;
+    options.mTileShape = SmallVector{*mTileShape};
+    options.nTileShape = SmallVector{*nTileShape};
     RewritePatternSet patterns(&getContext());
     populateBrgemmLinalgTilingPatterns(patterns, options);
 
     GreedyRewriteConfig config;
diff --git a/lib/TPP/Transforms/Bufferize.cpp b/lib/TPP/Transforms/Bufferize.cpp
index de6325b3f..96c65a953 100644
--- a/lib/TPP/Transforms/Bufferize.cpp
+++ b/lib/TPP/Transforms/Bufferize.cpp
@@ -142,8 +142,6 @@ void Bufferize::runOnOperation() {
 
   if (!runOnlyAnalysis) {
     passManager.addPass(bufferization::createDropEquivalentBufferResultsPass());
-    passManager.addNestedPass(
-        bufferization::createFinalizingBufferizePass());
 
     // Post-processing.
     passManager.addNestedPass(createCanonicalizerPass());
diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
index c1a3caec6..cc135512f 100644
--- a/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
+++ b/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
@@ -229,8 +229,8 @@ class LowerPacksAndUnPacks
       rewriter.replaceOp(packOp, tilingResult->replacements);
     });
     RewritePatternSet patterns(&getContext());
-    patterns.add(&getContext());
+    patterns.add(&getContext());
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
diff --git a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
index bc46c323d..540a10b22 100644
--- a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
+++ b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -20,6 +20,8 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+
+#include <optional>
 #include
 
 using namespace mlir;
@@ -421,15 +423,18 @@ static FailureOr fuseWithEltwise(
   tileAndFuseOptions.setTilingOptions(options);
   scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
       [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
-          bool isDestinationOperand) {
-        Operation *candidateOp = originalProducer.getOwner();
-        if (!candidateOp || worklist.count(candidateOp) == 0 ||
-            (alreadyFusedOps.count(candidateOp) &&
-             !isa(candidateOp))) {
-          return std::make_tuple(false, false);
-        }
-        return std::make_tuple(true, false);
-      };
+          bool isDestinationOperand)
+          -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+        Operation *candidateOp = originalProducer.getOwner();
+        if (!candidateOp || worklist.count(candidateOp) == 0 ||
+            (alreadyFusedOps.count(candidateOp) &&
+             !isa(candidateOp))) {
+          return std::nullopt;
+        }
+        scf::SCFTileAndFuseOptions::ControlFnResult res;
+        res.yieldProducerReplacement = false;
+        return res;
+      };
   tileAndFuseOptions.setFusionControlFn(controlFn);
   FailureOr tileAndFuseResult =
       scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer,
diff --git a/test/Dialect/Xsmm/xsmm-invalid.mlir b/test/Dialect/Xsmm/xsmm-invalid.mlir
index 58b7858ad..f46fe648b 100644
--- a/test/Dialect/Xsmm/xsmm-invalid.mlir
+++ b/test/Dialect/Xsmm/xsmm-invalid.mlir
@@ -419,7 +419,7 @@ func.func @binary_invoke(%arg1: memref<3x3xf32>) {
 
 // -----
 
 func.func @binary_invoke(%arg0: i64, %arg1: memref<3x3xf32>) {
-  // expected-error@+1 {{operands present, but expected 5}}
+  // expected-error@+1 {{custom op 'xsmm.binary' number of operands and types do not match: got 6 operands and 5 types}}
   xsmm.binary add(data_type = f32, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1) : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> ()
   return
diff --git a/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir b/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir
index da11149d6..75f875037 100644
--- a/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir
+++ b/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir
@@ -13,7 +13,7 @@
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-    : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+    = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
   func.func @entry(%arg0: tensor<8x8xf32> {bufferization.writable = true},
                    %arg1: tensor<8x8xf32> {bufferization.writable = true},
diff --git a/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir b/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir
index 4d3f847f6..192b4d1b0 100644
--- a/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir
+++ b/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir
@@ -14,15 +14,15 @@ module {
     // Kernel arguments are allocated on host
     // Copy data to device
    %0, %t0 = gpu.alloc async () : memref<8x8xf32>
-    %a0 = bufferization.to_memref %arg0 : memref<8x8xf32>
+    %a0 = bufferization.to_memref %arg0 : tensor<8x8xf32> to memref<8x8xf32>
     %t1 = gpu.memcpy async [%t0] %0, %a0 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t1]
     %1, %t2 = gpu.alloc async () : memref<8x8xf32>
-    %a1 = bufferization.to_memref %arg1 : memref<8x8xf32>
+    %a1 = bufferization.to_memref %arg1 : tensor<8x8xf32> to memref<8x8xf32>
     %t3 = gpu.memcpy async [%t2] %1, %a1 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t3]
     %2, %t4 = gpu.alloc async () : memref<8x8xf32>
-    %a2 = bufferization.to_memref %arg2 : memref<8x8xf32>
+    %a2 = bufferization.to_memref %arg2 : tensor<8x8xf32> to memref<8x8xf32>
     %t5 = gpu.memcpy async [%t4] %2, %a2 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t5]
 
@@ -46,7 +46,7 @@ module {
     %tD2 = gpu.dealloc async %2 : memref<8x8xf32>
     gpu.wait [%tD2]
 
-    %outTensor = bufferization.to_tensor %out restrict : memref<8x8xf32>
+    %outTensor = bufferization.to_tensor %out restrict : memref<8x8xf32> to tensor<8x8xf32>
 
     return %outTensor : tensor<8x8xf32>
   }
diff --git a/test/GPU/CUDA/Integration/linalg-mlp.mlir b/test/GPU/CUDA/Integration/linalg-mlp.mlir
index d9f40af93..ed6ff3465 100644
--- a/test/GPU/CUDA/Integration/linalg-mlp.mlir
+++ b/test/GPU/CUDA/Integration/linalg-mlp.mlir
@@ -9,7 +9,7 @@
 #map3 = affine_map<(d0, d1) -> (d0, d1)>
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-    : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+    = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
   func.func @entry(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
     %weights = arith.constant dense<0.1> : tensor<8x8xf32>
diff --git a/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir b/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir
index 83bd2247e..cbaa83c39 100644
--- a/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir
+++ b/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir
@@ -9,7 +9,7 @@
 // This requires either GPU unified memory or explicit data transfers to GPU.
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-    : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+    = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
   func.func @entry(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
     %c0 = arith.constant 0.0 : f32
diff --git a/test/Passes/tile-and-fuse-default.mlir b/test/Passes/tile-and-fuse-default.mlir
index bb4f9dfd9..53a21e003 100644
--- a/test/Passes/tile-and-fuse-default.mlir
+++ b/test/Passes/tile-and-fuse-default.mlir
@@ -733,7 +733,7 @@ func.func @contraction(%arg0: tensor<16x1xf32>, %arg1: tensor<1x32xf32>) -> tens
 
 // -----
 
 // CHECK-LABEL: dlti_tile_size_32
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
   func.func @dlti_tile_size_32(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>) -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
@@ -751,7 +751,7 @@ module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #
 
 // -----
 
 // CHECK-LABEL: dlti_tile_size_64
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
   func.func @dlti_tile_size_64(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>) -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
@@ -770,7 +770,7 @@ module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #
 
 // -----
 
 // CHECK-LABEL: dlti_tile_size_16
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
   func.func @dlti_tile_size_16(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>) -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index