Skip to content

Commit

Permalink
[DispatchCreation] Changes to dispatch region in preparation for hori…
Browse files Browse the repository at this point in the history
…zontal fusion changes.

This commit is mirror of #19876

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored and Max191 committed Feb 10, 2025
1 parent 52a2a39 commit 22cdf97
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
IRRewriter::Listener::notifyOperationInserted(op, previous);
if (isa<tensor::DimOp>(op))
auto dimOp = dyn_cast<tensor::DimOp>(op);
if (dimOp && isa<OpResult>(dimOp.getSource()))
dimOps.insert(op);
}

Expand All @@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
std::optional<int64_t> idx = dimOp.getConstantIndex();
if (!idx.has_value())
continue;

if (isa<BlockArgument>(dimOp.getSource())) {
continue;
}

// Only DimOps with ranked tensors are supported.
auto tensorType =
llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
if (!tensorType)
continue;

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(dimOp);
if (!tensorType.isDynamicDim(*idx)) {
// Rewrite static dimension with constant.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(dimOp);
int64_t size = tensorType.getShape()[*idx];
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
// Value is an OpResult.
Operation *op = value.getDefiningOp();
OpResult opResult = llvm::cast<OpResult>(value);
b.setInsertionPoint(op);

// Case 3: Value is tied. Reify the dimensions of the tied operand.
auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
if (tiedOp) {
Value tiedOperand = tiedOp.getTiedResultOperand(value);
if (tiedOperand && tiedOperand.getType() == value.getType())
return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
createTensorDimOps);
}

// Case 4: Query ShapeAwareOpInterface.
// Case 3: Query ShapeAwareOpInterface.
auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
if (shapeAwareOp) {
ValueRange dims =
Expand All @@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
return success();
}

// Case 4: Value is tied. Reify the dimensions of the tied operand.
auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
if (tiedOp) {
Value tiedOperand = tiedOp.getTiedResultOperand(value);
if (tiedOperand && tiedOperand.getType() == value.getType())
return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
/*createTensorDimOps=*/true);
}

// Case 5: Query ReifyRankedShapedTypeOpInterface.
auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
if (reifyShapeOp) {
Expand All @@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
}

/// Reify the dynamic dimensions of the given value.
/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
SmallVectorImpl<Value> &dynamicDims) {

OpBuilder::InsertionGuard g(b);
if (auto op = value.getDefiningOp()) {
b.setInsertionPoint(op);
}
return reifyDynamicResultDimsImpl(b, value, dynamicDims,
/*createTensorDimOps=*/true);
}
Expand Down Expand Up @@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
rewriter.setInsertionPoint(target);
SmallVector<Value> &dims =
dispatchOpNewResultsDynamicDims.emplace_back();
if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
return target->emitOpError(
"failed to reify dynamic dims of result to be yielded from "
"dispatch region");
Expand Down Expand Up @@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
for (auto [index, result] : llvm::enumerate(target->getResults())) {
replacedValues.push_back(result);
yieldedResults.push_back(clonedTarget->getResult(index));
rewriter.setInsertionPoint(target);
OpBuilder::InsertionGuard g1(rewriter);
rewriter.setInsertionPoint(regionOp);
SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
return target->emitOpError(
"failed to reify dynamic dims of result to be yielded from "
"dispatch region");
Expand Down
Loading

0 comments on commit 22cdf97

Please sign in to comment.