Skip to content

Commit

Permalink
Validation updated for random order of iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Dec 11, 2024
1 parent 8441ee0 commit c614f3f
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 9 deletions.
17 changes: 17 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class MemRefType;
namespace func {
class CallOp;
}
// Forward declarations of the linalg types that appear in the utility
// signatures declared further down in this header.
namespace linalg {
class GenericOp;
struct ContractionDimensions;
} // namespace linalg

namespace xsmm {
class UnaryKindAttr;
Expand Down Expand Up @@ -122,6 +126,19 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType,
SmallVector<XsmmOperand> operands, TypeRange results,
FlatSymbolRefAttr fnName, Operation *parentOp,
Operation *insertBefore);

/// NOTE(review): the definition is not visible in this change; presumably
/// returns the result position of `map` at which loop dimension `dim`
/// appears, or an empty optional when it does not - confirm against the
/// implementation.
std::optional<unsigned>
getPosInCodomain(unsigned dim, linalg::GenericOp linalgOp, AffineMap map);

/// Checks whether `linalgOp` has the structure expected of a VNNI GEMM:
/// exactly three operands, a BF16 first operand, at least four iterators,
/// inferable contraction dimensions, a reduction iterator in the last
/// position, a binary affine expression on the reduction dimension in the
/// second indexing map, and a multiply-add chain in its region.
/// Returns success() when every check passes and failure() otherwise.
LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
linalg::GenericOp linalgOp);

/// Infers the batch, m, n and k dimensions of `genericOp` from its first
/// three indexing maps and its iterator types. Each dimension list is
/// returned sorted in increasing order.
FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp genericOp);

/// Returns the position of the first result of `map` that is a binary affine
/// expression whose left-hand side is the dimension `index`.
std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
MLIRContext *context);

} // namespace utils
} // namespace xsmm
} // namespace mlir
Expand Down
144 changes: 142 additions & 2 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,27 @@

#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/IR/StructuredOpMatcher.h"
#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/TypeUtilities.h"

#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xsmm-utils"

using namespace mlir;
using namespace mlir::linalg;

namespace mlir {
namespace xsmm {
namespace utils {
Expand Down Expand Up @@ -564,6 +570,140 @@ FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
return attributes;
}

/// Collects the loop dimensions of kind `iter` that `indexingMap` accesses
/// through exactly one of its results.
///
/// A result contributes a dimension when it is either a plain dimension
/// expression or a binary expression whose left-hand side is a plain
/// dimension. Results of any other shape (constants, symbols, nested binary
/// expressions) name no single dimension and are skipped, instead of relying
/// on assertions that vanish in release builds.
///
/// \param indexingMap map whose results are inspected.
/// \param iterators   iterator kind of every loop dimension; must contain one
///                    entry per dimension of `indexingMap`.
/// \param iter        iterator kind to select.
/// \returns the set of selected dimension positions.
static llvm::SmallDenseSet<int64_t>
findIndexingOperand(AffineMap indexingMap,
                    ArrayRef<mlir::utils::IteratorType> iterators,
                    mlir::utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    // For a binary expression the dimension of interest is its LHS.
    AffineExpr candidate = e;
    if (auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(e))
      candidate = binaryExpr.getLHS();

    auto dimExpr = dyn_cast<AffineDimExpr>(candidate);
    if (!dimExpr)
      continue;

    const unsigned position = dimExpr.getPosition();
    if (iterators[position] != iter)
      continue;

    // Keep the dimension only when a single result of the map depends on it.
    const auto numUses =
        llvm::count_if(indexingMap.getResults(), [position](AffineExpr expr) {
          return expr.isFunctionOfDim(position);
        });
    if (numUses == 1)
      res.insert(position);
  }
  return res;
}
namespace {
// File-local shorthands for the two iterator kinds used when inferring
// contraction dimensions. They are compile-time constants, so declare them
// constexpr rather than as mutable globals.
constexpr auto par = mlir::utils::IteratorType::parallel;
constexpr auto red = mlir::utils::IteratorType::reduction;
} // namespace

/// Infers the contraction dimensions of `linalgOp` from its first three
/// indexing maps (treated as A, B and C) and its iterator types.
///
/// Unlike a plain projected-permutation analysis, the per-operand dimension
/// sets are computed by `findIndexingOperand`, which also looks through
/// binary affine expressions to the dimension on their left-hand side.
///
/// \returns failure() when the op has fewer than three indexing maps;
///          otherwise the batch/m/n/k dimension lists, each sorted in
///          increasing order. Any of the lists may be empty.
FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp linalgOp) {
  auto indexingMaps = linalgOp.getIndexingMapsArray();
  // The analysis below reads maps 0, 1 and 2; bail out rather than index out
  // of bounds when they are not all present.
  if (indexingMaps.size() < 3)
    return failure();

  auto iterators = linalgOp.getIteratorTypesArray();
  llvm::SmallDenseSet<int64_t> a =
      findIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Set iteration order is unspecified, so sort every list before returning.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.m);
  llvm::sort(dimensions.n);
  llvm::sort(dimensions.k);
  return dimensions;
}

/// Finds the first result of `map` that is a binary affine expression whose
/// left-hand side is the dimension `index`.
///
/// \param map     affine map whose results are searched in order.
/// \param index   position of the dimension expected on the left-hand side.
/// \param context context used to build the dimension expression to compare
///                against.
/// \returns the position of the matching result, or std::nullopt when no
///          result matches, so callers can turn a non-matching map into a
///          regular match failure instead of hitting unreachable code.
std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
                                                   MLIRContext *context) {
  // The expression to compare against does not change across iterations.
  const AffineExpr targetDim = getAffineDimExpr(index, context);
  for (unsigned i = 0; i < map.getNumResults(); i++) {
    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(map.getResult(i));
    if (binaryExpr && binaryExpr.getLHS() == targetDim)
      return i;
  }
  return std::nullopt;
}

/// Checks whether `linalgOp` has the structure expected of a VNNI GEMM.
///
/// The op must have exactly three operands, a BF16 first operand, at least
/// four iterators, inferable contraction dimensions with non-empty m and n
/// lists, a reduction iterator in the last position, a binary affine
/// expression on the reduction dimension in the second indexing map, and a
/// region made of a multiply followed by an add.
///
/// \param rewriter  used to query the data type attribute and the context.
/// \param linalgOp  candidate operation.
/// \returns success() when every check passes, failure() otherwise.
LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
                                     linalg::GenericOp linalgOp) {
  if (linalgOp->getNumOperands() != 3)
    return failure();

  if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) !=
      xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) {
    return failure();
  }
  auto iteratorTypes = linalgOp.getIteratorTypesArray();
  if (iteratorTypes.size() < 4)
    return failure();

  auto contractionDims = inferContractionDims(linalgOp);
  if (failed(contractionDims))
    return failure();

  // Either list may be empty for ops that are not contractions; calling
  // back() on an empty vector is invalid, so reject those ops here.
  if (contractionDims->m.empty() || contractionDims->n.empty())
    return failure();

  unsigned m = contractionDims->m.back();
  unsigned n = contractionDims->n.back();

  if (!linalg::isParallelIterator(iteratorTypes[m]) ||
      !linalg::isParallelIterator(iteratorTypes[n])) {
    return failure();
  }

  if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
    return failure();
  }

  // Fall back to dimension 0 when no reduction dimension was inferred.
  auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
  auto map1 = linalgOp.getIndexingMapsArray()[1];
  auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
  if (!index)
    return failure();

  // clang-format off
  using namespace mlir::structured_match;
  auto hasRightOpChain =
    StructuredOpMatcher::make<linalg::GenericOp>()
      .region(MatchOne(0), WithOpChain<KindMul, KindAdd>(
                                 /*captures=*/nullptr));
  // clang-format on
  if (!hasRightOpChain.match(linalgOp))
    return failure();
  return success();
}

template FailureOr<SmallVector<Attribute>>
getBrgemmFlags<xsmm::BrgemmDispatchOp>(PatternRewriter &rewriter,
xsmm::BrgemmDispatchOp dispatchOpTy,
Expand Down
15 changes: 8 additions & 7 deletions lib/TPP/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_VECTORIZATIONPASS
Expand All @@ -39,11 +38,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
PatternRewriter &rewriter) const override {
if (!linalgOp.hasPureBufferSemantics())
return failure();
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
xsmm::DataTypeAttr::get(rewriter.getContext(),
xsmm::DataType::BF16) &&
linalgOp.getIteratorTypes().size() >= 4 &&
linalgOp.getNumOperands() == 3) {
auto check = xsmm::utils::checkVNNIGemmStructure(rewriter, linalgOp);
if (succeeded(check)) {
SmallVector<int64_t> shape;
SmallVector<ReassociationIndices> indices;
int index = 0;
Expand Down Expand Up @@ -74,7 +70,12 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
auto map1 = linalgOp.getIndexingMapsArray()[1];
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
map0.getNumResults());
int map1Index = map1.getNumResults() - 3;
auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
map1, k, linalgOp.getContext());

AffineExpr expr = map1.getResult(map1Index);
if (isa<AffineBinaryOpExpr>(expr)) {

Expand Down

0 comments on commit c614f3f

Please sign in to comment.