Bump LLVM (#988)
Fixes after bump:
  - Update IR syntax
  - Resolve broken implicit conversions when initializing pass options
  - Remove deprecated partial bufferization
  - Fix APInt ctor assertion - use zero for undefined length vector read
adam-smnk authored Dec 10, 2024
1 parent f07997a commit d7fcc28
Showing 19 changed files with 73 additions and 50 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
@@ -1 +1 @@
-bf684034844c660b778f0eba103582f582b710c9
+3654f1baa66f524c89e40ab24e18e594e56363e9
14 changes: 8 additions & 6 deletions lib/TPP/DefaultPipeline.cpp
@@ -153,15 +153,17 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
       pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
     } else {
       // Apply the default preprocessing pass
       DefaultTppPassesOptions tppDefaultOptions;
       tppDefaultOptions.linalgToLoops = linalgToLoops;
-      tppDefaultOptions.parallelTaskGrid = parallelTaskGrid;
+      tppDefaultOptions.parallelTaskGrid = SmallVector<unsigned>{
+          parallelTaskGrid.begin(), parallelTaskGrid.end()};
       tppDefaultOptions.linalgToVector = linalgToVector;
       tppDefaultOptions.vectorToXSMM = vectorToXSMM;
-      tppDefaultOptions.lhsTile = lhsTile;
-      tppDefaultOptions.rhsTile = rhsTile;
-      tppDefaultOptions.lowerPackUnpackWithoutTranspose = lowerPackUnpackWithoutTranspose;
+      tppDefaultOptions.lowerPackUnpackWithoutTranspose =
+          lowerPackUnpackWithoutTranspose;
+      tppDefaultOptions.lhsTile =
+          SmallVector<unsigned>{lhsTile.begin(), lhsTile.end()};
+      tppDefaultOptions.rhsTile =
+          SmallVector<unsigned>{rhsTile.begin(), rhsTile.end()};
       tppDefaultOptions.vectorToKernel = vectorToKernel;
 
       pm.addPass(createDefaultTppPasses(tppDefaultOptions));
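Note on the pattern above (repeated in the pass files below): after this LLVM bump, pass ListOption values no longer convert implicitly into the SmallVector fields of the tablegen-generated option structs, so each assignment now copies the elements explicitly, either through begin()/end() iterators or through the option's dereference operator (as in SmallVector<unsigned>{*lhsTile} later in this diff). A minimal sketch of the migration, with a hypothetical MyPassOptions struct standing in for the generated ones:

    #include "llvm/ADT/SmallVector.h"

    // Hypothetical stand-in for a tablegen-generated pass options struct.
    struct MyPassOptions {
      llvm::SmallVector<unsigned> tileSizes;
    };

    // Works for any range-like option type (e.g. mlir::Pass::ListOption).
    template <typename RangeT>
    MyPassOptions buildOptions(const RangeT &tileSizesOption) {
      MyPassOptions options;
      // Explicit element-wise copy replaces the old implicit conversion.
      options.tileSizes = llvm::SmallVector<unsigned>{tileSizesOption.begin(),
                                                      tileSizesOption.end()};
      return options;
    }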
8 changes: 6 additions & 2 deletions lib/TPP/DefaultTppPasses.cpp
@@ -23,6 +23,8 @@
 #include "TPP/PassUtils.h"
 #include "mlir/Transforms/Passes.h"
 
+#include <string>
+
 using namespace mlir;
 using namespace mlir::tpp;

@@ -136,7 +138,8 @@ struct DefaultTppPasses
     if (linalgToVector || forceLinalgToVector) {
       // Vectorizes the remaining Linalg operations
       pm.addNestedPass<func::FuncOp>(createBrgemmLinalgTiling(
-          BrgemmLinalgTilingOptions{lhsTile, rhsTile}));
+          BrgemmLinalgTilingOptions{SmallVector<unsigned>{*lhsTile},
+                                    SmallVector<unsigned>{*rhsTile}}));
       pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
       pm.addNestedPass<func::FuncOp>(createVectorizationPass());
 
@@ -159,7 +162,8 @@ struct DefaultTppPasses
     // Convert forAll to parallel loops should run after bufferization
     // as scf.parallel does not handle tensor.
     pm.addPass(createConvertForAllToParallelOp());
-    LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid};
+    LowLevelParallelizationOptions LowLevelParallelization{
+        SmallVector<unsigned>{*parallelTaskGrid}};
 
     if (linalgToVector) {
       pm.addPass(createConvertVectorToSCFPass());
4 changes: 2 additions & 2 deletions lib/TPP/GPU/GpuConversion.cpp
@@ -65,8 +65,8 @@ struct GpuConversion : public tpp::impl::GpuConversionBase<GpuConversion>,
     // the default lowering for any remaining ops.
     pm.addNestedPass<func::FuncOp>(createLinalgDeGeneralize());
     if (isIntel) {
-      pm.addNestedPass<func::FuncOp>(
-          createLinalgToXeGPU(LinalgToXeGPUOptions{kTile, stages, dpasTile}));
+      pm.addNestedPass<func::FuncOp>(createLinalgToXeGPU(LinalgToXeGPUOptions{
+          kTile, stages, SmallVector<int64_t>{*dpasTile}}));
     }
     pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
     pm.addPass(createCleanup());
9 changes: 6 additions & 3 deletions lib/TPP/GPU/GpuPipeline.cpp
@@ -172,7 +172,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
     // First split computation into grid with blocks of specified size.
     TileConsumerAndFuseProducersOptions blockTileOptions;
     if (!llvm::any_of(gpuBlockTile, [](int64_t tile) { return tile == -1; }))
-      blockTileOptions.tileSizes = gpuBlockTile;
+      blockTileOptions.tileSizes =
+          SmallVector<int64_t>{gpuBlockTile.begin(), gpuBlockTile.end()};
     blockTileOptions.minTileFactor = 1;
     pm.addPass(createTileConsumerAndFuseProducers(blockTileOptions));
 
@@ -182,7 +183,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
     // chance for outlining.
     TileConsumerAndFuseProducersOptions threadTileOptions;
     if (!llvm::any_of(gpuThreadTile, [](int64_t tile) { return tile == -1; }))
-      threadTileOptions.tileSizes = gpuThreadTile;
+      threadTileOptions.tileSizes =
+          SmallVector<int64_t>{gpuThreadTile.begin(), gpuThreadTile.end()};
     threadTileOptions.minTileFactor = 1;
     pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions));
     pm.addPass(createCleanup());
@@ -214,7 +216,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
 
     // Convert to generic GPU ops.
     pm.addPass(createGpuConversion(GpuConversionOptions{
-        gpuType == GpuType::Intel, kTile, stages, gpuDpasTile}));
+        gpuType == GpuType::Intel, kTile, stages,
+        SmallVector<int64_t>{gpuDpasTile.begin(), gpuDpasTile.end()}}));
 
     // Lower GPU ops to the chosen GPU backend.
     switch (gpuType) {
19 changes: 13 additions & 6 deletions lib/TPP/GPU/LinalgToXeGPU.cpp
@@ -512,8 +512,9 @@ createGemmCoopPrefetchTile(PatternRewriter &rewriter, linalg::LinalgOp linalgOp,
 
   auto srcType = cast<ShapedType>(src.getType());
 
-  auto prefetchType =
-      xegpu::TensorDescType::get({numRows, numCols}, srcType.getElementType());
+  auto prefetchType = xegpu::TensorDescType::get(
+      {numRows, numCols}, srcType.getElementType(), /*array_length=*/1,
+      /*boundary_check=*/true);
 
   Value threadId = getGpuLinearThreadId(rewriter, loc);
 
@@ -620,7 +621,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
   assert(arrayLength == 1 && "Array descriptors are not supported");
 
   auto type = cast<ShapedType>(src.getType());
-  auto descType = xegpu::TensorDescType::get(descTile, type.getElementType());
+  auto descType = xegpu::TensorDescType::get(descTile, type.getElementType(),
+                                             /*array_length=*/1,
+                                             /*boundary_check=*/true);
 
   // Create the root descriptor.
   //
@@ -868,8 +871,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   int dimK = typeA.getShape().back();
 
   // Create C sub-tiles.
-  auto dpasTypeC = xegpu::TensorDescType::get({dpasTileM, dpasTileN},
-                                              typeC.getElementType());
+  auto dpasTypeC = xegpu::TensorDescType::get(
+      {dpasTileM, dpasTileN}, typeC.getElementType(), /*array_length=*/1,
+      /*boundary_check=*/true);
   SmallVector<Value> tilesC = createDescriptorTiles(
       rewriter, loc, matC, typeC.getShape(), {0, 0}, dpasTypeC.getShape());
 
@@ -1385,7 +1389,10 @@ struct LinalgToXeGPU : public tpp::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
   using LinalgToXeGPUBase::LinalgToXeGPUBase;
 
   void runOnOperation() override {
-    LinalgToXeGPUOptions options{kTile, stages, dpasTile};
+    LinalgToXeGPUOptions options;
+    options.kTile = kTile;
+    options.stages = stages;
+    options.dpasTile = SmallVector<int64_t>{*dpasTile};
 
     // Run GEMM pattern first to allow fusion with its consumers.
    RewritePatternSet gemmPatterns(&getContext());
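All four call sites above change for the same reason: xegpu::TensorDescType::get now takes the array length and boundary-check mode explicitly, and the patch passes /*array_length=*/1, /*boundary_check=*/true everywhere to keep the previous behavior. If more call sites appear, a small helper could centralize those defaults. A sketch only, makeTensorDesc is hypothetical and not part of the patch; it reuses exactly the get() overload shown in the diff:

    // Hypothetical convenience wrapper around the new overload.
    static xegpu::TensorDescType makeTensorDesc(llvm::ArrayRef<int64_t> shape,
                                                mlir::Type elementType) {
      return xegpu::TensorDescType::get(shape, elementType, /*array_length=*/1,
                                        /*boundary_check=*/true);
    }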
4 changes: 3 additions & 1 deletion lib/TPP/PassBundles/LinalgLowering.cpp
@@ -18,6 +18,8 @@
 #include "TPP/Dialect/Xsmm/XsmmDialect.h"
 #include "TPP/PassUtils.h"
 
+#include <string>
+
 using namespace mlir;
 using namespace mlir::tpp;

@@ -48,7 +50,7 @@ struct LinalgLowering : public tpp::impl::LinalgLoweringBase<LinalgLowering>,
 private:
   void constructPipeline() override {
     ConvertLinalgToXsmmOptions linalgOptions;
-    linalgOptions.skipOperations = skipOperations;
+    linalgOptions.skipOperations = SmallVector<std::string>{*skipOperations};
     pm.addPass(createConvertLinalgToXsmm(linalgOptions));
     pm.addPass(createCombineXsmmOpPass());
     pm.addPass(createFoldXsmmFlags());
2 changes: 1 addition & 1 deletion lib/TPP/PassBundles/LowLevelParallelization.cpp
@@ -63,7 +63,7 @@ struct LowLevelParallelization
     pm.addPass(createCleanup());
 
     mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
-    tilingOptions.tileSizes = parallelTaskGrid;
+    tilingOptions.tileSizes = SmallVector<unsigned>{*parallelTaskGrid};
     pm.addPass(createSCFParallelLoopTiling(tilingOptions));
   }
 };
6 changes: 3 additions & 3 deletions lib/TPP/Runner/MLIRBench.cpp
@@ -351,8 +351,8 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
   outerDim = outputType.getShape()[0];
 
   // Vector undefined value
-  Value minusOne = builder.create<arith::ConstantOp>(
-      unkLoc, getTypedAttr(builder, outElmType, -1.0));
+  Value undefLengthCst = builder.create<arith::ConstantOp>(
+      unkLoc, getTypedAttr(builder, outElmType, 0.0));
 
   // Loop through the shaped type, transfer each dim to vector
   auto count = getConstIndex(builder, outerDim);
@@ -364,7 +364,7 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
     // Loop body
     auto beginIdx = loop.getInductionVar();
     auto vector = builder.create<vector::TransferReadOp>(
-        unkLoc, vecType, val, ValueRange{beginIdx, zero}, minusOne);
+        unkLoc, vecType, val, ValueRange{beginIdx, zero}, undefLengthCst);
     printVector(vector);
 
     // Finally lower to LLVM Dialect
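Background on this rename: the constant is the padding value for vector.transfer_read, which is only observed on out-of-bounds lanes, so any value works functionally. Our reading of the assertion (the commit message does not spell it out): upstream APInt constructors now assert that the given value is representable in the requested bit width instead of truncating implicitly, and the old -1.0, funneled through getTypedAttr for a narrow integer element type, becomes an all-ones 64-bit value that no longer fits. Zero is representable at every width, hence the switch. A self-contained illustration under that assumption:

    #include "llvm/ADT/APInt.h"
    #include <cstdint>

    void paddingValueExample() {
      llvm::APInt zero(/*numBits=*/8, /*val=*/0); // fits in any width
      // All-ones squeezed into 8 bits without an isSigned hint: with the
      // newer no-implicit-truncation default, this asserts in debug builds.
      // llvm::APInt allOnes(/*numBits=*/8, /*val=*/uint64_t(-1));
    }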
4 changes: 3 additions & 1 deletion lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -223,7 +223,9 @@ struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
   using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;
 
   void runOnOperation() override {
-    BrgemmLinalgTilingOptions options{mTileShape, nTileShape};
+    BrgemmLinalgTilingOptions options;
+    options.mTileShape = SmallVector<unsigned>{*mTileShape};
+    options.nTileShape = SmallVector<unsigned>{*nTileShape};
     RewritePatternSet patterns(&getContext());
     populateBrgemmLinalgTilingPatterns(patterns, options);
     GreedyRewriteConfig config;
2 changes: 0 additions & 2 deletions lib/TPP/Transforms/Bufferize.cpp
@@ -142,8 +142,6 @@ void Bufferize::runOnOperation() {
 
   if (!runOnlyAnalysis) {
     passManager.addPass(bufferization::createDropEquivalentBufferResultsPass());
-    passManager.addNestedPass<func::FuncOp>(
-        bufferization::createFinalizingBufferizePass());
 
     // Post-processing.
     passManager.addNestedPass<func::FuncOp>(createCanonicalizerPass());
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
@@ -229,8 +229,8 @@ class LowerPacksAndUnPacks
         rewriter.replaceOp(packOp, tilingResult->replacements);
       });
       RewritePatternSet patterns(&getContext());
-      patterns.add<linalg::GeneralizeOuterUnitDimsUnPackOpPattern,
-                   linalg::GeneralizeOuterUnitDimsPackOpPattern>(&getContext());
+      patterns.add<linalg::DecomposeOuterUnitDimsUnPackOpPattern,
+                   linalg::DecomposeOuterUnitDimsPackOpPattern>(&getContext());
       tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
       if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                               std::move(patterns)))) {
23 changes: 14 additions & 9 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -20,6 +20,8 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+
+#include <optional>
 #include <queue>
 
 using namespace mlir;
@@ -421,15 +423,18 @@ static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
   tileAndFuseOptions.setTilingOptions(options);
   scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
       [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
-          bool isDestinationOperand) {
-        Operation *candidateOp = originalProducer.getOwner();
-        if (!candidateOp || worklist.count(candidateOp) == 0 ||
-            (alreadyFusedOps.count(candidateOp) &&
-             !isa<linalg::FillOp>(candidateOp))) {
-          return std::make_tuple(false, false);
-        }
-        return std::make_tuple(true, false);
-      };
+          bool isDestinationOperand)
+          -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+        Operation *candidateOp = originalProducer.getOwner();
+        if (!candidateOp || worklist.count(candidateOp) == 0 ||
+            (alreadyFusedOps.count(candidateOp) &&
+             !isa<linalg::FillOp>(candidateOp))) {
+          return std::nullopt;
+        }
+        scf::SCFTileAndFuseOptions::ControlFnResult res;
+        res.yieldProducerReplacement = false;
+        return res;
+      };
   tileAndFuseOptions.setFusionControlFn(controlFn);
   FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
       scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer,
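The fusion control callback changed shape upstream: it used to return std::tuple<bool, bool> (fuse this candidate?, yield a replacement for the fused producer?) and now returns std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>, where std::nullopt rejects the fusion. A minimal accept-everything callback under the new signature, mirroring the types used in the diff above:

    // Accepts every candidate and never yields a producer replacement,
    // i.e. the equivalent of the old `return std::make_tuple(true, false)`.
    scf::SCFTileAndFuseOptions::ControlFnTy acceptAll =
        [](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
           bool isDestinationOperand)
        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
      scf::SCFTileAndFuseOptions::ControlFnResult res;
      res.yieldProducerReplacement = false;
      return res;
    };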
2 changes: 1 addition & 1 deletion test/Dialect/Xsmm/xsmm-invalid.mlir
@@ -419,7 +419,7 @@ func.func @binary_invoke(%arg1: memref<3x3xf32>) {
 // -----
 
 func.func @binary_invoke(%arg0: i64, %arg1: memref<3x3xf32>) {
-  // expected-error@+1 {{operands present, but expected 5}}
+  // expected-error@+1 {{custom op 'xsmm.binary' number of operands and types do not match: got 6 operands and 5 types}}
   xsmm.binary add(data_type = f32, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1)
     : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> ()
   return
2 changes: 1 addition & 1 deletion test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir
@@ -13,7 +13,7 @@
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-  : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+  = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
 func.func @entry(%arg0: tensor<8x8xf32> {bufferization.writable = true},
                  %arg1: tensor<8x8xf32> {bufferization.writable = true},
8 changes: 4 additions & 4 deletions test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir
@@ -14,15 +14,15 @@ module {
     // Kernel arguments are allocated on host
     // Copy data to device
     %0, %t0 = gpu.alloc async () : memref<8x8xf32>
-    %a0 = bufferization.to_memref %arg0 : memref<8x8xf32>
+    %a0 = bufferization.to_memref %arg0 : tensor<8x8xf32> to memref<8x8xf32>
     %t1 = gpu.memcpy async [%t0] %0, %a0 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t1]
     %1, %t2 = gpu.alloc async () : memref<8x8xf32>
-    %a1 = bufferization.to_memref %arg1 : memref<8x8xf32>
+    %a1 = bufferization.to_memref %arg1 : tensor<8x8xf32> to memref<8x8xf32>
     %t3 = gpu.memcpy async [%t2] %1, %a1 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t3]
     %2, %t4 = gpu.alloc async () : memref<8x8xf32>
-    %a2 = bufferization.to_memref %arg2 : memref<8x8xf32>
+    %a2 = bufferization.to_memref %arg2 : tensor<8x8xf32> to memref<8x8xf32>
     %t5 = gpu.memcpy async [%t4] %2, %a2 : memref<8x8xf32>, memref<8x8xf32>
     gpu.wait [%t5]
 
@@ -46,7 +46,7 @@ module {
     %tD2 = gpu.dealloc async %2 : memref<8x8xf32>
     gpu.wait [%tD2]
 
-    %outTensor = bufferization.to_tensor %out restrict : memref<8x8xf32>
+    %outTensor = bufferization.to_tensor %out restrict : memref<8x8xf32> to tensor<8x8xf32>
 
     return %outTensor : tensor<8x8xf32>
   }
2 changes: 1 addition & 1 deletion test/GPU/CUDA/Integration/linalg-mlp.mlir
@@ -9,7 +9,7 @@
 #map3 = affine_map<(d0, d1) -> (d0, d1)>
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-  : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+  = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
 func.func @entry(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %weights = arith.constant dense<0.1> : tensor<8x8xf32>
2 changes: 1 addition & 1 deletion test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir
@@ -9,7 +9,7 @@
 // This requires either GPU unified memory or explicit data transfers to GPU.
 module attributes {
   "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
-  : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
+  = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>>
 } {
 func.func @entry(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c0 = arith.constant 0.0 : f32
6 changes: 3 additions & 3 deletions test/Passes/tile-and-fuse-default.mlir
@@ -733,7 +733,7 @@ func.func @contraction(%arg0: tensor<16x1xf32>, %arg1: tensor<1x32xf32>) -> tens
 // -----
 
 // CHECK-LABEL: dlti_tile_size_32
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
   func.func @dlti_tile_size_32(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
     -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
@@ -751,7 +751,7 @@ module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #
 // -----
 
 // CHECK-LABEL: dlti_tile_size_64
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
   func.func @dlti_tile_size_64(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
     -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
@@ -770,7 +770,7 @@ module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #
 // -----
 
 // CHECK-LABEL: dlti_tile_size_16
-module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
  func.func @dlti_tile_size_16(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
     -> tensor<2048x2048xf32> {
     // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
