Skip to content

Commit

Permalink
Added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Dec 11, 2024
1 parent c614f3f commit a2bb306
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,21 +677,25 @@ LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
unsigned m = contractionDims->m.back();
unsigned n = contractionDims->n.back();

// m and n dimensions must be parallel dimensions
if (!linalg::isParallelIterator(iteratorTypes[m]) ||
!linalg::isParallelIterator(iteratorTypes[n])) {
return failure();
}

// innermost dimension must be a reduction dimension for VNNI type operations
if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
return failure();
}

// get the index of the iterator corresponding to the floordiv operation
auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
auto map1 = linalgOp.getIndexingMapsArray()[1];
auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
if (!index)
return failure();

// Ensure that the body of the generic operation is mul-add chain
// clang-format off
using namespace mlir::structured_match;
auto hasRightOpChain =
Expand Down
4 changes: 4 additions & 0 deletions lib/TPP/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,13 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
}
auto map0 = linalgOp.getIndexingMapsArray()[0];
auto map1 = linalgOp.getIndexingMapsArray()[1];
// Set the innermost dimension of the first map to vnni dimension
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
map0.getNumResults());
auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
// Get the index of the iterator corresponding to the floordiv operation
auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
map1, k, linalgOp.getContext());

Expand All @@ -82,6 +84,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
auto expand = rewriter.create<memref::ExpandShapeOp>(
linalgOp.getLoc(), shape, linalgOp.getOperand(0), indices);
linalgOp.setOperand(0, expand.getResult());
// Replace the floordiv operation with just the LHS of the floordiv
// expression
map1 = map1.insertResult(
dyn_cast<AffineBinaryOpExpr>(map1.getResult(map1Index)).getLHS(),
map1Index + 1);
Expand Down

0 comments on commit a2bb306

Please sign in to comment.