Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): Type inference rewriter: Handle return-like operations correctly #796

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,40 +161,32 @@ class TypeInferenceRewriter {
})) {
resolvedOperandTypes = llvm::to_vector(op->getOperandTypes());
resolvedResultTypes = llvm::to_vector(op->getResultTypes());
} else if (op->hasTrait<mlir::OpTrait::ReturnLike>()) {
// Return-like ops are a bit special, since they simply forward
// values upwards. The types inferred for the operands may come
// from producers, which are in the same block, while the result
} else {
LocalInferenceState inferredTypes =
TypeInferenceUtils::getLocalInferenceState(solver, op);

// Return-like ops are a bit special, since their operand types
// are tied to the types of their parent op. The types inferred
// for the operands of a return-like operation may come from
// producers, which are in the same block, while the result
// types of the parent op may have been deduced from producers
// or consumers of the block containing the parent operation.
//
// Blindly taking the operand types may thus result in a
// mismatch between the final return types of the parent op and
// the operand types of the return-like op.
//
// In the general case, simply take the result types of the
// parent operation, which, at this point, has already been
// partially rewritten before recursing into this rewrite call.
// Blindly applying the inferred operand types may thus result
// in a mismatch between the final return types of the parent op
// and the operand types of the return-like op.
//
// Functions are a bit different though, since the types of the
// results are contained in a function type and not in the
// result types.
mlir::Operation *newParent = mapping.lookup(op->getParentOp());

if (llvm::isa<mlir::func::ReturnOp>(op)) {
mlir::func::FuncOp newParentFunc =
llvm::dyn_cast<mlir::func::FuncOp>(newParent);
resolvedOperandTypes =
llvm::to_vector(newParentFunc.getFunctionType().getResults());
} else {
// Look up new parent op and use the return types, since these
// are the authoritative types obtained from the last invocation
// of the type resolver
resolvedOperandTypes = llvm::to_vector(newParent->getResultTypes());
// Instead, look up the rewritten parent op and add its return
// types to the local inference state before invoking type
// inference a last time.
if (op->hasTrait<mlir::OpTrait::ReturnLike>()) {
mlir::Operation *newParent = mapping.lookup(op->getParentOp());

for (auto [oldResult, newResult] : llvm::zip_equal(
op->getParentOp()->getResults(), newParent->getResults())) {
inferredTypes.set(oldResult, newResult.getType());
}
}
} else {
LocalInferenceState inferredTypes =
TypeInferenceUtils::getLocalInferenceState(solver, op);

resolvedTypes = typeResolver.resolve(op, inferredTypes);

Expand Down
Loading