[Intel] Sync AxisInfo from upstream (#3558)
Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang authored Feb 27, 2025
1 parent cb6bfd6 commit b7604a9
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -1,5 +1,6 @@
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"

@@ -273,6 +274,28 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   }
 };
 
+class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
+public:
+  using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
+
+  AxisInfo
+  getAxisInfo(ub::PoisonOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
+    constexpr int64_t largePowerOf2 = int64_t(1) << 32;
+    // Poison values are never accessed, thus assume optimistic values.
+    if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
+      unsigned rank = shape.getRank();
+      return AxisInfo(
+          /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
+          /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
+          /*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
+    }
+
+    return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
+                    /*constancy=*/{1});
+  }
+};
+
 template <typename OpTy>
 class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 public:
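The `largePowerOf2` defaults above are safe because AxisInfo merges lattice values with a per-dimension gcd, so a huge power of two acts as a near-identity when a poison operand is joined with a real one. A minimal standalone sketch of that arithmetic (plain C++17, illustrative only, not part of this diff):

#include <cstdint>
#include <iostream>
#include <numeric>

int main() {
  constexpr int64_t largePowerOf2 = int64_t(1) << 32;
  // One branch proves divisibility 16; the other branch is poison and
  // reports the optimistic default.
  int64_t proven = 16;
  // gcd(2^32, 16) == 16: the poison default never lowers the merged result.
  std::cout << std::gcd(largePowerOf2, proven) << "\n"; // prints 16
}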
@@ -946,7 +969,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
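The widened literal matters because `1` is a 32-bit `int`: once `shift` reaches 31 or more, `1 << shift` is undefined behavior, and in practice the divisor wraps even though `lhsDivisibility` itself is 64-bit. A standalone illustration of the fixed expression (a sketch, not the upstream code):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t lhsDivisibility = int64_t(1) << 40; // divisible by 2^40
  int64_t shift = 35;
  // int bad = 1 << shift;                    // UB: int shifted by >= 32
  int64_t divisor = int64_t(1) << shift;      // well-defined 64-bit shift
  std::cout << std::max<int64_t>(1, lhsDivisibility / divisor) << "\n"; // 32
}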
@@ -1092,6 +1115,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
   // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
   // when scf.for supports integer induction variables
   visitors.append<MakeRangeOpAxisInfoVisitor>();
+  visitors.append<PoisonOpAxisInfoVisitor>();
   visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
                   ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
   visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
@@ -1184,15 +1208,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
   auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
   if (!tensorTy)
     return 1;
-  auto layout = tensorTy.getEncoding();
 
-  // Here order should be ordered by contiguous first, so the first element
-  // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  // FIXME: This is not as good as it could be, as we don't need to restrict
+  // the analysis to one dimension. We should determine contiguity on the
+  // flattenOuts() layout
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto order = linAttr.getOrder();
   unsigned align = getPtrAlignment(ptr);
 
-  auto uniqueContigPerThread =
-      triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
+  auto uniqueContigPerThread = linAttr.getContigPerThread();
   assert(order[0] < uniqueContigPerThread.size() &&
          "Unexpected uniqueContigPerThread size");
   unsigned contiguity = uniqueContigPerThread[order[0]];
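Just past this hunk, the function clamps what the linear layout promises by what the alignment analysis actually proves, roughly `std::min(align, contiguity)` in the upstream source. The sketch below shows that interaction with made-up numbers (the names are placeholders, not the real API):

#include <algorithm>
#include <iostream>

int main() {
  unsigned contigPerThread = 16; // layout: 16 consecutive elements per thread
  unsigned align = 4;            // AxisInfo proves only 4-element alignment
  // The usable contiguity is bounded by both quantities.
  std::cout << std::min(align, contigPerThread) << "\n"; // prints 4
}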
@@ -1209,8 +1234,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   auto *axisInfo = getAxisInfo(ptr);
   if (!axisInfo)
     return 1;
-  auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto order = linAttr.getOrder();
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   unsigned elemNumBits = isTensorPointerType(ptr.getType())
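`getPtrAlignment` combines the two AxisInfo quantities read through `order[0]`: the byte divisibility is converted into an element count and capped by the proven contiguity. A simplified sketch of that combination (function and variable names here are illustrative, not the exact upstream code):

#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t ptrAlignmentElems(int64_t maxMultipleBytes, int64_t maxContig,
                          int64_t elemNumBits) {
  int64_t elemNumBytes = std::max<int64_t>(elemNumBits / 8, 1);
  // How many whole elements the byte divisibility covers, at least one.
  int64_t maxMultipleElems =
      std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
  // Alignment cannot exceed the number of provably contiguous elements.
  return std::min(maxMultipleElems, maxContig);
}

int main() {
  // Pointer divisible by 64 bytes, 8 contiguous 32-bit (f32) elements:
  std::cout << ptrAlignmentElems(64, 8, 32) << "\n"; // prints 8
}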
@@ -1239,7 +1265,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto maskOrder = linAttr.getOrder();
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
        << alignment);
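For masks, the relevant AxisInfo component is constancy rather than contiguity: if `getConstancy(maskOrder[0])` proves that k adjacent lanes share one mask value, a k-wide vectorized access can be guarded by a single predicate, and the `std::max` simply keeps the result at least 1 when nothing is proven. A small illustration (values are made up):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t constancy = 8; // 8 adjacent elements share one mask value
  unsigned alignment = std::max<unsigned>(constancy, 1);
  std::cout << alignment << "\n"; // prints 8: one predicate per 8-wide access
}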
