Skip to content

Commit

Permalink
fix(compiler): Type inference rewriter: Handle return-like operations…
Browse files Browse the repository at this point in the history
… correctly

Return-like operations require special treatment by the type inference
rewriter, since their operand types are both tied to the result types
of their producers and to result types of their parent operations.

The inference scheme for ordinary operations, in which the initial
local inference state is composed of the operand types of the
rewritten producers and the old types of related operations before
rewriting is insufficient, since this may result in a mismatch between
the inferred types and the actual types of the already rewritten
parent operation.

Until now, precedence of the new result types of the parent operation
has been implemented by simply designating these types as the operand
types of a return-like operation. However, while this works as
intended for return-like operations, which simply forward values
(e.g., `func.return`), this creates invalid IR for other return-like
operations (e.g., `tensor.yield`).

This change implements precedence of the result types of the parent
operation of a return-like operation by adding the return types of the
already rewritten parent operation to the initial local inference
state before final invocation of type inference.
  • Loading branch information
andidr committed Apr 19, 2024
1 parent 46f92ec commit 064c424
Showing 1 changed file with 21 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,40 +161,32 @@ class TypeInferenceRewriter {
})) {
resolvedOperandTypes = llvm::to_vector(op->getOperandTypes());
resolvedResultTypes = llvm::to_vector(op->getResultTypes());
} else if (op->hasTrait<mlir::OpTrait::ReturnLike>()) {
// Return-like ops are a bit special, since they simply forward
// values upwards. The types inferred for the operands may come
// from producers, which are in the same block, while the result
} else {
LocalInferenceState inferredTypes =
TypeInferenceUtils::getLocalInferenceState(solver, op);

// Return-like ops are a bit special, since their operand types
// are tied to the types of their parent op. The types inferred
// for the operands of a return-like operation may come from
// producers, which are in the same block, while the result
// types of the parent op may have been deduced from producers
// or consumers of the block containing the parent operation.
//
// Blindly taking the operand types may thus result in a
// mismatch between the final return types of the parent op and
// the operand types of the return-like op.
//
// In the general case, simply take the result types of the
// parent operation, which, at this point, has already been
// partially rewritten before recursing into this rewrite call.
// Blindly applying the inferred operand types may thus result
// in a mismatch between the final return types of the parent op
// and the operand types of the return-like op.
//
// Functions are a bit different though, since the types of the
// results are contained in a function type and not in the
// result types.
mlir::Operation *newParent = mapping.lookup(op->getParentOp());

if (llvm::isa<mlir::func::ReturnOp>(op)) {
mlir::func::FuncOp newParentFunc =
llvm::dyn_cast<mlir::func::FuncOp>(newParent);
resolvedOperandTypes =
llvm::to_vector(newParentFunc.getFunctionType().getResults());
} else {
// Look up new parent op and use the return types, since these
// are the authoritative types obtained from the last invocation
// of the type resolver
resolvedOperandTypes = llvm::to_vector(newParent->getResultTypes());
// Instead, look up the rewritten parent op and add its return
// types to the local inference state before invoking type
// inference a last time.
if (op->hasTrait<mlir::OpTrait::ReturnLike>()) {
mlir::Operation *newParent = mapping.lookup(op->getParentOp());

for (auto [oldResult, newResult] : llvm::zip_equal(
op->getParentOp()->getResults(), newParent->getResults())) {
inferredTypes.set(oldResult, newResult.getType());
}
}
} else {
LocalInferenceState inferredTypes =
TypeInferenceUtils::getLocalInferenceState(solver, op);

resolvedTypes = typeResolver.resolve(op, inferredTypes);

Expand Down

0 comments on commit 064c424

Please sign in to comment.