Bump LLVM
Fixes for upstream API changes and improves tests' checks.
adam-smnk committed Aug 19, 2024
1 parent 619fde1 commit 6d14b4f
Showing 6 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
@@ -1 +1 @@
-1846523bb77275de954ac573110171bd39bfa930
+79f6ae05c139d3d5b6446f8a265a3c6e3f5b18f8
21 changes: 10 additions & 11 deletions lib/TPP/GPU/LinalgToXeGPU.cpp
@@ -344,7 +344,7 @@ static std::optional<Value> lowerEltwiseOp(linalg::LinalgOp linalgOp,
         // Unhandled type. Bail out.
         return std::nullopt;
       })
-      .Case([&](linalg::NegfOp negfOp) -> std::optional<Value> {
+      .Case([&](linalg::NegFOp negfOp) -> std::optional<Value> {
         assert(operands.size() == 1 && "Invalid number of operands for negf");
         return rewriter.create<arith::NegFOp>(loc, resType, operands[0])
             .getResult();
@@ -724,17 +724,17 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

   VectorType vecLoadType =
       VectorType::get(tileType.getShape(), tileType.getElementType());
-  IntegerAttr vnniAxisAttr = nullptr;
+  UnitAttr vnniPackedAttr = nullptr;
   if (vnniConf) {
-    vnniAxisAttr = IntegerAttr::get(rewriter.getI64Type(), vnniConf->vnniAxis);
+    vnniPackedAttr = rewriter.getUnitAttr();
     vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
                                 *vnniConf);
   }

   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
     auto loadOp = rewriter.create<xegpu::LoadNdOp>(
-        loc, vecLoadType, tile, vnniAxisAttr, transpose,
+        loc, vecLoadType, tile, vnniPackedAttr, transpose,
         /*l1_hint=*/hint,
         /*l2_hint=*/hint, /*l3_hint=*/hint);
     loadVec.push_back(loadOp);
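
For context, the upstream API change this hunk tracks: xegpu::LoadNdOp dropped its integer vnni_axis attribute in favor of a unit packed attribute, hence the IntegerAttr-to-UnitAttr swap above. A condensed before/after sketch of the builder call (variable names from this diff; the attribute name packed is our reading of the upstream change):

// Before the bump: the VNNI axis was passed explicitly as an IntegerAttr.
rewriter.create<xegpu::LoadNdOp>(loc, vecLoadType, tile,
                                 /*vnni_axis=*/vnniAxisAttr, transpose,
                                 /*l1_hint=*/hint, /*l2_hint=*/hint,
                                 /*l3_hint=*/hint);

// After the bump: packing is a simple on/off flag; the VNNI shape is carried
// by the result vector type computed via getVnniVector instead.
rewriter.create<xegpu::LoadNdOp>(loc, vecLoadType, tile,
                                 /*packed=*/vnniPackedAttr, transpose,
                                 /*l1_hint=*/hint, /*l2_hint=*/hint,
                                 /*l3_hint=*/hint);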
@@ -1043,12 +1043,11 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   if (vnniFactor == -1)
     return failure();

-  VnniConfig vnniConfA{.vnniFactor = vnniFactor, .vnniAxis = 1};
   VnniConfig vnniConfB{.vnniFactor = vnniFactor, .vnniAxis = 0};

   // Load A sub-tiles.
-  SmallVector<Value> loadVecA =
-      loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
+  SmallVector<Value> loadVecA = loadNdDescTiles(
+      rewriter, loc, tilesA, readCacheHint, /*vnniConf=*/std::nullopt);
   auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

   // Load B sub-tiles.
@@ -1077,9 +1076,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   }

   // Extract DPAS tiles from loaded sub-tiles.
-  TilesArray dpasVecA = extractVecSubTiles(rewriter, loc, loadVecA,
-                                           {dimM, kTile}, tileTypeA.getShape(),
-                                           {dpasTileM, dpasTileK}, vnniConfA);
+  TilesArray dpasVecA = extractVecSubTiles(
+      rewriter, loc, loadVecA, {dimM, kTile}, tileTypeA.getShape(),
+      {dpasTileM, dpasTileK}, /*vnniConf=*/std::nullopt);
   TilesArray dpasVecB = extractVecSubTiles(rewriter, loc, loadVecB,
                                            {kTile, dimN}, tileTypeB.getShape(),
                                            {dpasTileK, dpasTileN}, vnniConfB);
@@ -1378,7 +1377,7 @@ void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
                ConvertNamedEltwiseToXeGPU<linalg::FloorOp>,
                ConvertNamedEltwiseToXeGPU<linalg::MaxOp>,
                ConvertNamedEltwiseToXeGPU<linalg::MulOp>,
-               ConvertNamedEltwiseToXeGPU<linalg::NegfOp>,
+               ConvertNamedEltwiseToXeGPU<linalg::NegFOp>,
                ConvertNamedEltwiseToXeGPU<linalg::SubOp>>(patterns.getContext(),
                                                           options);
 }
6 changes: 4 additions & 2 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -652,7 +652,7 @@ struct PropagatePackUnPack
     MLIRContext *ctx = getOperation().getContext();
     RewritePatternSet patterns(ctx);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](Operation *op) { return true; });
+        patterns, [](OpOperand *operand) { return true; });
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
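
The signature change behind this hunk: linalg::populateDataLayoutPropagationPatterns now hands its control function the OpOperand being propagated through rather than an Operation. A minimal sketch of how a filter recovers the ops on either side under the new signature (illustrative, not part of this commit):

// Illustrative sketch: the consumer is the operand's owner; the producer is
// the defining op of the operand's value (null for block arguments).
auto controlFn = [](OpOperand *operand) {
  Operation *consumer = operand->getOwner();            // always non-null
  Operation *producer = operand->get().getDefiningOp(); // may be null
  (void)consumer;
  return producer != nullptr; // example predicate only
};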
@@ -813,7 +813,9 @@ void mlir::tpp::populateSimplifyPacking(RewritePatternSet &patterns) {
   // Propagate packs/unpacks only through expand shapes at this point.
   // This captures the transformation scope of the replaced downstream pass.
   linalg::populateDataLayoutPropagationPatterns(
-      patterns, [](Operation *op) { return isa<tensor::ExpandShapeOp>(op); });
+      patterns, [](OpOperand *operand) {
+        return isa<tensor::ExpandShapeOp>(operand->get().getDefiningOp());
+      });
   ctx->getLoadedDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
       patterns);
   patterns.add<FoldUnPackIntoInsertSlice>(ctx);
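
One caveat with the new predicate above: Value::getDefiningOp() returns null when the value is a block argument, and llvm::isa asserts on null pointers. A defensive variant of this filter (a sketch, not what the commit ships) would guard first:

linalg::populateDataLayoutPropagationPatterns(
    patterns, [](OpOperand *operand) {
      Operation *def = operand->get().getDefiningOp();
      return def && isa<tensor::ExpandShapeOp>(def); // skip block arguments
    });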
4 changes: 2 additions & 2 deletions test/GPU/linalg-to-xegpu-dpas.mlir
@@ -63,9 +63,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
 // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
-// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
+// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
 // CHECK-COUNT-3: vector.extract_strided_slice

 // Perform DPAS computation.
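
The two updated shape_cast checks reflect tile A now being loaded without VNNI packing: for f16 with vnniFactor = 2, the packed form of a 32x16 tile is 32x8x2, since the VNNI axis is halved and the factor becomes a trailing dimension, leaving the flat element count (512) unchanged. A small standalone sketch of that shape computation, mirroring the role getVnniVector plays in this commit (hypothetical helper, for illustration only):

#include <cstdint>
#include <vector>

// Hypothetical helper: divide dimension `axis` by `factor` and append the
// factor as a trailing dimension, e.g. {32, 16}, factor 2, axis 1 -> {32, 8, 2}.
std::vector<int64_t> vnniShape(std::vector<int64_t> shape, int64_t factor,
                               size_t axis) {
  shape[axis] /= factor;
  shape.push_back(factor);
  return shape;
}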
2 changes: 1 addition & 1 deletion test/Passes/tile-and-fuse-default.mlir
@@ -595,7 +595,7 @@ func.func @check_tile_propagation_to_eltwise_consumer(%arg0: tensor<2x2x2x4xf32>
 // CHECK-LABEL: check_tile_propagation_to_eltwise_consumer
 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x2x2x4xf32>, %[[ARG1:.+]]: tensor<2x4x8x2xf32>,
 // CHECK-SAME: %[[ARG2:.+]]: tensor<2x2x8x2xf32>, %[[ARG3:.+]]: tensor<2x2x8x2xf32>
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
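
The CHECK-to-CHECK-DAG switch here (mirrored in test/Passes/tile-and-fuse.mlir below) is what the commit message means by improving the tests' checks: plain CHECK lines must match in source order, while a run of CHECK-DAG lines may match in any order, so the test no longer depends on the order in which the constants are materialized. For illustration, either output order now matches:

  %c8 = arith.constant 8 : index
  %c2 = arith.constant 2 : index

or

  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index

Previously the leading plain CHECK forced %c8 to appear before the CHECK-DAG group could begin matching.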
2 changes: 1 addition & 1 deletion test/Passes/tile-and-fuse.mlir
@@ -342,7 +342,7 @@ func.func @mlp(%arg0: tensor<8x112x32x32xbf16>, %arg1: tensor<112x112x32x32xbf16
     %max = arith.maximumf %in, %cst : bf16
     linalg.yield %max : bf16
   } -> tensor<8x112x32x32xbf16>
-  // CHECK: %[[C112:.+]] = arith.constant 112 : index
+  // CHECK-DAG: %[[C112:.+]] = arith.constant 112 : index
   // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
   // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
