diff --git a/include/nil/blueprint/parser.hpp b/include/nil/blueprint/parser.hpp index aed64b6c..2251ec81 100644 --- a/include/nil/blueprint/parser.hpp +++ b/include/nil/blueprint/parser.hpp @@ -139,20 +139,20 @@ namespace nil { } } if(!is_const && llvm::isa(inst->getOperand(i))) { - put_constant(llvm::cast(inst->getOperand(i)), frame, next_prover); + put_constant(llvm::cast(inst->getOperand(i)), frame, true); } } return true; } - std::string extract_metadata(const llvm::Instruction *inst) { + std::uint32_t extract_prover_idx_metadata(const llvm::Instruction *inst) { const llvm::MDNode* metaDataNode = inst->getMetadata("zk_multi_prover"); if (metaDataNode) { const llvm::MDString *MDS = llvm::dyn_cast(metaDataNode->getOperand(0)); - return MDS->getString().str(); + return std::stoi(MDS->getString().str()); } - return ""; + return currProverIdx; } template @@ -400,7 +400,7 @@ namespace nil { for (int i = 0; i < inst->getNumOperands(); ++i) { llvm::Value *op = inst->getOperand(i); if (llvm::isa(op)) { - put_constant(llvm::cast(op), frame, next_prover); + put_constant(llvm::cast(op), frame, true); } } } @@ -811,13 +811,7 @@ namespace nil { std::uint32_t start_row = assignments[currProverIdx].allocated_rows(); // extract zk related metadata - const std::string metadataStr = extract_metadata(inst); - std::uint32_t userProverIdx = currProverIdx; - try { - userProverIdx = std::stoi(metadataStr); - } catch(...) { - userProverIdx = currProverIdx; - } + std::uint32_t userProverIdx = extract_prover_idx_metadata(inst); if (userProverIdx < currProverIdx || userProverIdx >= maxNumProvers) { std::cout << "WARNING: ignored user defined prover index " << userProverIdx @@ -835,13 +829,8 @@ namespace nil { bool next_prover = false; if (inst->getNextNonDebugInstruction()) { - const std::string nextInstructionMetadataStr = extract_metadata(inst->getNextNonDebugInstruction()); - try { - const std::uint32_t nextUserProverIdx = std::stoi(nextInstructionMetadataStr); - next_prover = currProverIdx != nextUserProverIdx; - } catch(...) { - next_prover = false; - } + const std::uint32_t nextUserProverIdx = extract_prover_idx_metadata(inst->getNextNonDebugInstruction()); + next_prover = currProverIdx != nextUserProverIdx; } // Put constant operands to public input @@ -858,7 +847,7 @@ namespace nil { // In other cases the logic remains unchanged if (inst->getOpcode() != llvm::Instruction::Call || !llvm::cast(inst)->getCalledFunction()->isIntrinsic()) { - put_constant(llvm::cast(op), frame, next_prover); + put_constant(llvm::cast(op), frame, true); } } } @@ -1028,6 +1017,7 @@ namespace nil { } new_frame.caller = call_inst; + new_frame.next_prover = next_prover; call_stack.emplace(std::move(new_frame)); stack_memory.push_frame(); return &fun->begin()->front(); @@ -1346,13 +1336,15 @@ namespace nil { return nullptr; } + next_prover = next_prover || extracted_frame.next_prover; if (inst->getNumOperands() != 0) { llvm::Value *ret_val = inst->getOperand(0); llvm::Type *ret_type= ret_val->getType(); if (ret_type->isVectorTy() || ret_type->isCurveTy() || (ret_type->isFieldTy() && field_arg_num(ret_type) > 1)) { auto &upper_frame_vectors = call_stack.top().vectors; - auto res = extracted_frame.vectors[ret_val]; + auto res = next_prover ? + save_shared_var(assignments[currProverIdx], extracted_frame.vectors[ret_val]) : extracted_frame.vectors[ret_val]; upper_frame_vectors[extracted_frame.caller] = res; } else if (ret_type->isAggregateType()) { ptr_type ret_ptr = resolve_number(extracted_frame, ret_val); @@ -1368,7 +1360,8 @@ namespace nil { upper_frame_variables[extracted_frame.caller] = put_into_assignment(allocated_copy, next_prover); } else { auto &upper_frame_variables = call_stack.top().scalars; - upper_frame_variables[extracted_frame.caller] = extracted_frame.scalars[ret_val]; + upper_frame_variables[extracted_frame.caller] = next_prover ? + save_shared_var(assignments[currProverIdx], extracted_frame.scalars[ret_val]) : extracted_frame.scalars[ret_val]; } } return extracted_frame.caller->getNextNonDebugInstruction(); @@ -1395,6 +1388,7 @@ namespace nil { stack_frame base_frame; auto &variables = base_frame.scalars; base_frame.caller = nullptr; + base_frame.next_prover = false; auto entry_point_it = module.end(); for (auto function_it = module.begin(); function_it != module.end(); ++function_it) { if (function_it->hasFnAttribute(llvm::Attribute::Circuit)) { diff --git a/include/nil/blueprint/stack.hpp b/include/nil/blueprint/stack.hpp index dd42fb7b..d0bf6f6b 100644 --- a/include/nil/blueprint/stack.hpp +++ b/include/nil/blueprint/stack.hpp @@ -49,6 +49,7 @@ namespace nil { std::map scalars; std::map> vectors; const llvm::CallInst *caller; + bool next_prover; }; } // namespace blueprint