Skip to content

Commit

Permalink
Keep next_prover of caller on Call
Browse files Browse the repository at this point in the history
  • Loading branch information
akokoshn authored and nkaskov committed Dec 6, 2023
1 parent 0a4aa15 commit b1f7c1b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
38 changes: 16 additions & 22 deletions include/nil/blueprint/parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,20 @@ namespace nil {
}
}
if(!is_const && llvm::isa<llvm::Constant>(inst->getOperand(i))) {
put_constant(llvm::cast<llvm::Constant>(inst->getOperand(i)), frame, next_prover);
put_constant(llvm::cast<llvm::Constant>(inst->getOperand(i)), frame, true);
}
}
return true;
}


std::string extract_metadata(const llvm::Instruction *inst) {
std::uint32_t extract_prover_idx_metadata(const llvm::Instruction *inst) {
const llvm::MDNode* metaDataNode = inst->getMetadata("zk_multi_prover");
if (metaDataNode) {
const llvm::MDString *MDS = llvm::dyn_cast<llvm::MDString>(metaDataNode->getOperand(0));
return MDS->getString().str();
return std::stoi(MDS->getString().str());
}
return "";
return currProverIdx;
}

template<typename map_type>
Expand Down Expand Up @@ -400,7 +400,7 @@ namespace nil {
for (int i = 0; i < inst->getNumOperands(); ++i) {
llvm::Value *op = inst->getOperand(i);
if (llvm::isa<llvm::Constant>(op)) {
put_constant(llvm::cast<llvm::Constant>(op), frame, next_prover);
put_constant(llvm::cast<llvm::Constant>(op), frame, true);
}
}
}
Expand Down Expand Up @@ -811,13 +811,7 @@ namespace nil {
std::uint32_t start_row = assignments[currProverIdx].allocated_rows();

// extract zk related metadata
const std::string metadataStr = extract_metadata(inst);
std::uint32_t userProverIdx = currProverIdx;
try {
userProverIdx = std::stoi(metadataStr);
} catch(...) {
userProverIdx = currProverIdx;
}
std::uint32_t userProverIdx = extract_prover_idx_metadata(inst);

if (userProverIdx < currProverIdx || userProverIdx >= maxNumProvers) {
std::cout << "WARNING: ignored user defined prover index " << userProverIdx
Expand All @@ -835,13 +829,8 @@ namespace nil {

bool next_prover = false;
if (inst->getNextNonDebugInstruction()) {
const std::string nextInstructionMetadataStr = extract_metadata(inst->getNextNonDebugInstruction());
try {
const std::uint32_t nextUserProverIdx = std::stoi(nextInstructionMetadataStr);
next_prover = currProverIdx != nextUserProverIdx;
} catch(...) {
next_prover = false;
}
const std::uint32_t nextUserProverIdx = extract_prover_idx_metadata(inst->getNextNonDebugInstruction());
next_prover = currProverIdx != nextUserProverIdx;
}

// Put constant operands to public input
Expand All @@ -858,7 +847,7 @@ namespace nil {
// In other cases the logic remains unchanged
if (inst->getOpcode() != llvm::Instruction::Call ||
!llvm::cast<llvm::CallInst>(inst)->getCalledFunction()->isIntrinsic()) {
put_constant(llvm::cast<llvm::Constant>(op), frame, next_prover);
put_constant(llvm::cast<llvm::Constant>(op), frame, true);
}
}
}
Expand Down Expand Up @@ -1028,6 +1017,7 @@ namespace nil {

}
new_frame.caller = call_inst;
new_frame.next_prover = next_prover;
call_stack.emplace(std::move(new_frame));
stack_memory.push_frame();
return &fun->begin()->front();
Expand Down Expand Up @@ -1346,13 +1336,15 @@ namespace nil {

return nullptr;
}
next_prover = next_prover || extracted_frame.next_prover;
if (inst->getNumOperands() != 0) {
llvm::Value *ret_val = inst->getOperand(0);
llvm::Type *ret_type= ret_val->getType();
if (ret_type->isVectorTy() || ret_type->isCurveTy()
|| (ret_type->isFieldTy() && field_arg_num<BlueprintFieldType>(ret_type) > 1)) {
auto &upper_frame_vectors = call_stack.top().vectors;
auto res = extracted_frame.vectors[ret_val];
auto res = next_prover ?
save_shared_var(assignments[currProverIdx], extracted_frame.vectors[ret_val]) : extracted_frame.vectors[ret_val];
upper_frame_vectors[extracted_frame.caller] = res;
} else if (ret_type->isAggregateType()) {
ptr_type ret_ptr = resolve_number<ptr_type>(extracted_frame, ret_val);
Expand All @@ -1368,7 +1360,8 @@ namespace nil {
upper_frame_variables[extracted_frame.caller] = put_into_assignment(allocated_copy, next_prover);
} else {
auto &upper_frame_variables = call_stack.top().scalars;
upper_frame_variables[extracted_frame.caller] = extracted_frame.scalars[ret_val];
upper_frame_variables[extracted_frame.caller] = next_prover ?
save_shared_var(assignments[currProverIdx], extracted_frame.scalars[ret_val]) : extracted_frame.scalars[ret_val];
}
}
return extracted_frame.caller->getNextNonDebugInstruction();
Expand All @@ -1395,6 +1388,7 @@ namespace nil {
stack_frame<var> base_frame;
auto &variables = base_frame.scalars;
base_frame.caller = nullptr;
base_frame.next_prover = false;
auto entry_point_it = module.end();
for (auto function_it = module.begin(); function_it != module.end(); ++function_it) {
if (function_it->hasFnAttribute(llvm::Attribute::Circuit)) {
Expand Down
1 change: 1 addition & 0 deletions include/nil/blueprint/stack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace nil {
std::map<const llvm::Value *, VarType> scalars;
std::map<const llvm::Value *, std::vector<VarType>> vectors;
const llvm::CallInst *caller;
bool next_prover;
};

} // namespace blueprint
Expand Down

0 comments on commit b1f7c1b

Please sign in to comment.