diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h b/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h index c874a69f6e..bef222cd66 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Transforms/TypeInferenceRewriter.h @@ -161,40 +161,32 @@ class TypeInferenceRewriter { })) { resolvedOperandTypes = llvm::to_vector(op->getOperandTypes()); resolvedResultTypes = llvm::to_vector(op->getResultTypes()); - } else if (op->hasTrait()) { - // Return-like ops are a bit special, since they simply forward - // values upwards. The types inferred for the operands may come - // from producers, which are in the same block, while the result + } else { + LocalInferenceState inferredTypes = + TypeInferenceUtils::getLocalInferenceState(solver, op); + + // Return-like ops are a bit special, since their operand types + // are tied to the types of their parent op. The types inferred + // for the operands of a return-like operation may come from + // producers, which are in the same block, while the result // types of the parent op may have been deduced from producers // or consumers of the block containing the parent operation. // - // Blindly taking the operand types may thus result in a - // mismatch between the final return types of the parent op and - // the operand types of the return-like op. - // - // In the general case, simply take the result types of the - // parent operation, which, at this point, has already been - // partially rewritten before recursing into this rewrite call. + // Blindly applying the inferred operand types may thus result + // in a mismatch between the final return types of the parent op + // and the operand types of the return-like op. // - // Functions are a bit different though, since the types of the - // results are contained in a function type and not in the - // result types. - mlir::Operation *newParent = mapping.lookup(op->getParentOp()); - - if (llvm::isa(op)) { - mlir::func::FuncOp newParentFunc = - llvm::dyn_cast(newParent); - resolvedOperandTypes = - llvm::to_vector(newParentFunc.getFunctionType().getResults()); - } else { - // Look up new parent op and use the return types, since these - // are the authoritative types obtained from the last invocation - // of the type resolver - resolvedOperandTypes = llvm::to_vector(newParent->getResultTypes()); + // Instead, look up the rewritten parent op and add its return + // types to the local inference state before invoking type + // inference a last time. + if (op->hasTrait()) { + mlir::Operation *newParent = mapping.lookup(op->getParentOp()); + + for (auto [oldResult, newResult] : llvm::zip_equal( + op->getParentOp()->getResults(), newParent->getResults())) { + inferredTypes.set(oldResult, newResult.getType()); + } } - } else { - LocalInferenceState inferredTypes = - TypeInferenceUtils::getLocalInferenceState(solver, op); resolvedTypes = typeResolver.resolve(op, inferredTypes);