From 064c42452f7420ae3b4d690d6892ec821e8fcb60 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 19 Apr 2024 06:22:56 +0200 Subject: [PATCH] fix(compiler): Type inference rewriter: Handle return-like operations correctly Return-like operations require special treatment by the type inference rewriter, since their operand types are both tied to the result types of their producers and to result types of their parent operations. The inference scheme for ordinary operations, in which the initial local inference state is composed of the operand types of the rewritten producers and the old types of related operations before rewriting is insufficient, since this may result in a mismatch between the inferred types and the actual types of the already rewritten parent operation. Until now, precedence of the new result types of the parent operation has been implemented by simply designating these types as the operand types of a return-like operation. However, while this works as intended for return-like operations, which simply forward values (e.g., `func.return`), this creates invalid IR for other return-like operations (e.g., `tensor.yield`). This change implements precedence of the result types of the parent operation of a return-like operation by adding the return types of the already rewritten parent operation to the initial local inference state before final invocation of type inference. --- .../Transforms/TypeInferenceRewriter.h | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h b/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h index c874a69f6e..bef222cd66 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h @@ -161,40 +161,32 @@ class TypeInferenceRewriter { })) { resolvedOperandTypes = llvm::to_vector(op->getOperandTypes()); resolvedResultTypes = llvm::to_vector(op->getResultTypes()); - } else if (op->hasTrait()) { - // Return-like ops are a bit special, since they simply forward - // values upwards. The types inferred for the operands may come - // from producers, which are in the same block, while the result + } else { + LocalInferenceState inferredTypes = + TypeInferenceUtils::getLocalInferenceState(solver, op); + + // Return-like ops are a bit special, since their operand types + // are tied to the types of their parent op. The types inferred + // for the operands of a return-like operation may come from + // producers, which are in the same block, while the result // types of the parent op may have been deduced from producers // or consumers of the block containing the parent operation. // - // Blindly taking the operand types may thus result in a - // mismatch between the final return types of the parent op and - // the operand types of the return-like op. - // - // In the general case, simply take the result types of the - // parent operation, which, at this point, has already been - // partially rewritten before recursing into this rewrite call. + // Blindly applying the inferred operand types may thus result + // in a mismatch between the final return types of the parent op + // and the operand types of the return-like op. // - // Functions are a bit different though, since the types of the - // results are contained in a function type and not in the - // result types. - mlir::Operation *newParent = mapping.lookup(op->getParentOp()); - - if (llvm::isa(op)) { - mlir::func::FuncOp newParentFunc = - llvm::dyn_cast(newParent); - resolvedOperandTypes = - llvm::to_vector(newParentFunc.getFunctionType().getResults()); - } else { - // Look up new parent op and use the return types, since these - // are the authoritative types obtained from the last invocation - // of the type resolver - resolvedOperandTypes = llvm::to_vector(newParent->getResultTypes()); + // Instead, look up the rewritten parent op and add its return + // types to the local inference state before invoking type + // inference a last time. + if (op->hasTrait()) { + mlir::Operation *newParent = mapping.lookup(op->getParentOp()); + + for (auto [oldResult, newResult] : llvm::zip_equal( + op->getParentOp()->getResults(), newParent->getResults())) { + inferredTypes.set(oldResult, newResult.getType()); + } } - } else { - LocalInferenceState inferredTypes = - TypeInferenceUtils::getLocalInferenceState(solver, op); resolvedTypes = typeResolver.resolve(op, inferredTypes);