Skip to content

Commit

Permalink
[GPU] Add check for contractionOpInterface in setMatmulLoweringConfig (
Browse files Browse the repository at this point in the history
…iree-org#18178)

The `setMatmulLoweringConfig` function only checks that
`linalg::inferContractionDims` returns contraction dimensions, but not
that the operation is a contraction. The inferContractionDims function
does not check that maps are projected permutations, so it can infer
contraction dims on non-pure contraction ops, like convolutions. This
causes a segmentation fault because the assumption of
`setMatmulLoweringConfig` is that there the op is a pure contraction.
This fixes the bug by checking for `linalg::isaContractionOpInterface`
as well.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Aug 9, 2024
1 parent ab12a4e commit 5a48912
Showing 1 changed file with 12 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp) {
if (!linalgOp || !linalg::isaContractionOpInterface(linalgOp)) {
return failure();
}

Expand All @@ -39,23 +39,20 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();

SmallVector<int64_t, 4> bounds = linalgOp.getStaticLoopRanges();
FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
mlir::linalg::inferContractionDims(linalgOp);
if (failed(contractionDims)) {
return failure();
}
mlir::linalg::ContractionDimensions contractionDims =
mlir::linalg::inferContractionDims(linalgOp).value();

if (contractionDims->k.empty() || contractionDims->m.empty() ||
contractionDims->n.empty()) {
if (contractionDims.k.empty() || contractionDims.m.empty() ||
contractionDims.n.empty()) {
return failure();
}

// For now we are not being smart and trying to reshape dimensions to allow
// for better usage of intrinsics, and instead are tiling all dimensions
// except the inner most m, n, and k dimensions to 1.
int64_t mDim = contractionDims->m.back();
int64_t nDim = contractionDims->n.back();
int64_t kDim = contractionDims->k.back();
int64_t mDim = contractionDims.m.back();
int64_t nDim = contractionDims.n.back();
int64_t kDim = contractionDims.k.back();

// Dynamic dims are expected to be taken care of earlier in the pipeline.
if (ShapedType::isDynamic(bounds[mDim]) ||
Expand Down Expand Up @@ -159,19 +156,19 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
SmallVector<int64_t> subgroupTileSizes(linalgOp.getNumLoops(), 0);
// Tile all batch dimensions with unit size.
for (int64_t batch : contractionDims->batch) {
for (int64_t batch : contractionDims.batch) {
workgroupTileSizes[batch] = 1;
}

// Tile all m, n, and k dimensions to 1 except the innermost. Unit dims
// from this tiling are folded before vectorization.
for (int64_t m : llvm::drop_end(contractionDims->m)) {
for (int64_t m : llvm::drop_end(contractionDims.m)) {
workgroupTileSizes[m] = 1;
}
for (int64_t n : llvm::drop_end(contractionDims->n)) {
for (int64_t n : llvm::drop_end(contractionDims.n)) {
workgroupTileSizes[n] = 1;
}
for (int64_t k : llvm::drop_end(contractionDims->k)) {
for (int64_t k : llvm::drop_end(contractionDims.k)) {
reductionTileSizes[k] = 1;
}

Expand Down

0 comments on commit 5a48912

Please sign in to comment.