Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve virtual register usage for optimizations. #60

Merged
merged 5 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions backend/remill/include/remill/BC/HelperMacro.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
#pragma once

// #define LIFT_DEBUG 1
// #define LIFT_CALLSTACK_DEBUG 1
// #define LIFT_INSN_DEBUG 1
// #define LIFT_MEMORY_VALUE_CHANGE 1
// #define ELFCONV_SYSCALL_DEBUG 1

// #define WARNING_OUTPUT 1
// #define OPT_ALGO_DEBUG 1
// #define OPT_GEN_IR_DEBUG 1
// #define OPT_CALL_FUNC_DEBUG 1
// #define OPT_REAL_REGS_DEBUG 1
// #define OPT_REAL_REGS_DEBUG 1
10 changes: 5 additions & 5 deletions backend/remill/include/remill/BC/InstructionLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ class EcvReg {
static std::pair<EcvReg, EcvRegClass> GetRegInfo(const std::string &_reg_name);

std::string GetRegName(EcvRegClass ecv_reg_class) const;
bool CheckNoChangedReg() const;
bool CheckPassedArgsRegs() const;
bool CheckPassedReturnRegs() const;

class Hash {
public:
Expand Down Expand Up @@ -195,6 +196,8 @@ class BBRegInfoNode {
// Save the args registers by semantic functions (for debug)
std::unordered_map<llvm::CallInst *, std::vector<std::pair<EcvReg, EcvRegClass>>>
sema_func_args_reg_map;
// Save the pc of semantics functions (for debug)
std::unordered_map<llvm::CallInst *, uint64_t> sema_func_pc_map;

std::unordered_map<llvm::Value *, std::pair<EcvReg, EcvRegClass>> post_update_regs;

Expand All @@ -214,14 +217,12 @@ class InstructionLifterIntf : public OperandLifter {
// this instruction will execute within the delay slot of another instruction.
virtual LiftStatus LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
llvm::Value *state_ptr, BBRegInfoNode *bb_reg_info_node,
uint64_t debug_insn_addr = UINT64_MAX,
bool is_delayed = false) = 0;

// Lift a single instruction into a basic block. `is_delayed` signifies that
// this instruction will execute within the delay slot of another instruction.
LiftStatus LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
BBRegInfoNode *bb_reg_info_node, uint64_t debug_insn_addr = UINT64_MAX,
bool is_delayed = false);
BBRegInfoNode *bb_reg_info_node, bool is_delayed = false);
};

// Wraps the process of lifting an instruction into a block. This resolves
Expand All @@ -244,7 +245,6 @@ class InstructionLifter : public InstructionLifterIntf {
// this instruction will execute within the delay slot of another instruction.
virtual LiftStatus LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
llvm::Value *state_ptr, BBRegInfoNode *bb_reg_info_node,
uint64_t debug_insn_addr = UINT64_MAX,
bool is_delayed = false) override;


Expand Down
2 changes: 1 addition & 1 deletion backend/remill/include/remill/BC/SleighLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class SleighLifterWithState final : public InstructionLifterIntf {
// this instruction will execute within the delay slot of another instruction.
virtual LiftStatus LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
llvm::Value *state_ptr, BBRegInfoNode *bb_reg_info_node,
uint64_t __debug_insn_addr, bool is_delayed = false) override;
bool is_delayed = false) override;

virtual llvm::Value *LoadRegValueBeforeInst(llvm::BasicBlock *block, llvm::Value *state_ptr,
std::string_view reg_name,
Expand Down
63 changes: 49 additions & 14 deletions backend/remill/include/remill/BC/TraceLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ class TraceManager {
class PhiRegsBBBagNode {
public:
PhiRegsBBBagNode(EcvRegMap<EcvRegClass> __preceding_load_reg_map,
EcvRegMap<EcvRegClass> &&__succeeding_load_reg_map,
EcvRegMap<EcvRegClass> __succeeding_load_reg_map,
EcvRegMap<EcvRegClass> &&__within_store_reg_map,
std::set<llvm::BasicBlock *> &&__in_bbs)
: bag_preceding_load_reg_map(std::move(__preceding_load_reg_map)),
bag_succeeding_load_reg_map(std::move(__succeeding_load_reg_map)),
: bag_preceding_load_reg_map(__preceding_load_reg_map),
bag_succeeding_load_reg_map(__succeeding_load_reg_map),
bag_preceding_store_reg_map(std::move(__within_store_reg_map)),
in_bbs(std::move(__in_bbs)),
converted_bag(nullptr) {}
Expand All @@ -127,11 +127,14 @@ class PhiRegsBBBagNode {
static void GetPrecedingVirtualRegsBags(llvm::BasicBlock *root_bb);
static void GetSucceedingVirtualRegsBags(llvm::BasicBlock *root_bb);
static void RemoveLoop(llvm::BasicBlock *bb);
static void GetPhiRegsBags(llvm::BasicBlock *root_bb);
static void
GetPhiRegsBags(llvm::BasicBlock *root_bb,
std::unordered_map<llvm::BasicBlock *, BBRegInfoNode *> &bb_info_node_map);

static inline std::unordered_map<llvm::BasicBlock *, PhiRegsBBBagNode *> bb_regs_bag_map = {};
static inline std::size_t bag_num = 0;
static inline std::unordered_map<PhiRegsBBBagNode *, uint32_t> debug_bag_map = {};
// The register set which should be passed from caller function.

PhiRegsBBBagNode *GetTrueBag();
void MergePrecedingRegMap(PhiRegsBBBagNode *moved_bag);
Expand Down Expand Up @@ -196,16 +199,37 @@ class VirtualRegsOpt {
relay_bb_cache({}),
phi_val_order(0),
fun_vma(__fun_vma) {
for (auto &arg : func->args()) {
if (arg.getName() == "state") {
arg_state_val = &arg;
} else if (arg.getName() == "runtime_manager") {
arg_runtime_val = &arg;
arg_state_val = NULL;
arg_runtime_val = NULL;
// only declared function.
if (func->getName().str() == "__remill_function_call") {
auto args = func->args().begin();
for (size_t i = 0; i < func->arg_size(); i++) {
if (0 == i) {
CHECK(llvm::dyn_cast<llvm::PointerType>(args[i].getType()));
arg_state_val = &args[i];
} else if (2 == i) {
CHECK(llvm::dyn_cast<llvm::PointerType>(args[i].getType()));
arg_runtime_val = &args[i];
}
}
}
CHECK(arg_state_val) << "[Bug] state arg is empty at the initialization of VirtualRegsOpt.";
// lifted function.
else {
for (auto &arg : func->args()) {
if (arg.getName() == "state") {
arg_state_val = &arg;
} else if (arg.getName() == "runtime_manager") {
arg_runtime_val = &arg;
}
}
}
CHECK(arg_state_val)
<< "[Bug] state arg is empty at the initialization of VirtualRegsOpt. target func: "
<< func->getName().str();
CHECK(arg_runtime_val)
<< "[Bug] runtime_manager arg is empty at the initialization of VirtualRegsOpt.";
<< "[Bug] runtime_manager arg is empty at the initialization of VirtualRegsOpt. target func: "
<< func->getName().str();
}
VirtualRegsOpt() {}
~VirtualRegsOpt() {}
Expand All @@ -223,8 +247,15 @@ class VirtualRegsOpt {
std::unordered_map<EcvReg, std::tuple<EcvRegClass, llvm::Value *, uint32_t>, EcvReg::Hash>
&cache_map);

void AnalyzeRegsBags();
static void CalPassedCallerRegForBJump();

void OptimizeVirtualRegsUsage();

static inline std::unordered_map<llvm::Function *, VirtualRegsOpt *> func_v_r_opt_map = {};
static inline std::unordered_map<llvm::Function *, std::vector<llvm::Function *>>
b_jump_callees_map = {};

llvm::Function *func;
TraceLifter::Impl *impl;
llvm::Value *arg_state_val;
Expand All @@ -242,13 +273,19 @@ class VirtualRegsOpt {

uint64_t phi_val_order;

std::unordered_map<llvm::BasicBlock *, PhiRegsBBBagNode *> bb_regs_bag_map;
EcvRegMap<EcvRegClass> passed_caller_reg_map;
EcvRegMap<EcvRegClass> passed_callee_ret_reg_map;

std::set<llvm::ReturnInst *> ret_inst_set;

// for debug
uint64_t fun_vma;
uint64_t block_num;
std::string func_name;
// map llvm::Value* and the corresponding CPU register.
std::unordered_map<llvm::Value *, std::pair<EcvReg, EcvRegClass>> value_reg_map;
static inline std::set<EcvReg> debug_reg_set = {};
std::set<EcvReg> debug_reg_set = {};

void InsertDebugVmaAndRegisters(
llvm::Instruction *inst_at_before,
Expand Down Expand Up @@ -365,15 +402,13 @@ class TraceLifter::Impl {
std::string inst_bytes;
Instruction inst;
Instruction delayed_inst;
std::set<uint64_t> control_flow_debug_fnvma_set;
DecoderWorkList trace_work_list;
DecoderWorkList inst_work_list;
DecoderWorkList dead_inst_work_list;
uint64_t __trace_addr;
std::map<uint64_t, llvm::BasicBlock *> blocks;
VirtualRegsOpt *virtual_regs_opt;

std::unordered_map<llvm::Function *, VirtualRegsOpt *> func_virtual_regs_opt_map;
std::set<llvm::Function *> no_indirect_lifted_funcs;
std::set<llvm::Function *> lifted_funcs;

Expand Down
43 changes: 25 additions & 18 deletions backend/remill/lib/BC/InstructionLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,12 @@ std::string EcvReg::GetRegName(EcvRegClass ecv_reg_class) const {
return "";
}

bool EcvReg::CheckNoChangedReg() const {
return STATE_ORDER == number || RUNTIME_ORDER == number || IGNORE_WRITE_TO_WZR_ORDER == number ||
IGNORE_WRITE_TO_XZR_ORDER == number;
// Reports whether this register's number lies in 0..7 or equals SP_ORDER.
// NOTE(review): presumably these are the argument-passing registers plus the
// stack pointer of the target calling convention — confirm against the ABI.
bool EcvReg::CheckPassedArgsRegs() const {
  if (SP_ORDER == number) {
    return true;
  }
  return 0 <= number && number <= 7;
}

// Reports whether this register's number lies in 0..1 or equals SP_ORDER.
// NOTE(review): presumably these are the return-value registers plus the
// stack pointer of the target calling convention — confirm against the ABI.
bool EcvReg::CheckPassedReturnRegs() const {
  const bool is_stack_pointer = (SP_ORDER == number);
  const bool is_low_numbered = (0 <= number && number <= 1);
  return is_low_numbered || is_stack_pointer;
}

std::string EcvRegClass2String(EcvRegClass ecv_reg_class) {
Expand Down Expand Up @@ -234,7 +237,9 @@ BBRegInfoNode::BBRegInfoNode(llvm::Function *func, llvm::Value *state_val,
void BBRegInfoNode::join_reg_info_node(BBRegInfoNode *child) {
// Join bb_load_reg_map
for (auto [_ecv_reg, _ecv_reg_class] : child->bb_load_reg_map) {
bb_load_reg_map.insert({_ecv_reg, _ecv_reg_class});
if (!bb_store_reg_map.contains(_ecv_reg)) {
bb_load_reg_map.insert({_ecv_reg, _ecv_reg_class});
}
}
// Join bb_store_reg_map
for (auto [_ecv_reg, _ecv_reg_class] : child->bb_store_reg_map) {
Expand All @@ -258,6 +263,10 @@ void BBRegInfoNode::join_reg_info_node(BBRegInfoNode *child) {
for (auto key_value : child->sema_func_args_reg_map) {
sema_func_args_reg_map.insert(key_value);
}
// Join sema_func_pc_map
for (auto key_value : child->sema_func_pc_map) {
sema_func_pc_map.insert(key_value);
}
}

InstructionLifter::Impl::Impl(const Arch *arch_, const IntrinsicTable *intrinsics_)
Expand Down Expand Up @@ -286,10 +295,9 @@ InstructionLifter::InstructionLifter(const Arch *arch_, const IntrinsicTable *in
// Lift a single instruction into a basic block. `is_delayed` signifies that
// this instruction will execute within the delay slot of another instruction.
LiftStatus InstructionLifterIntf::LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
BBRegInfoNode *bb_reg_info_node,
uint64_t debug_insn_addr, bool is_delayed) {
BBRegInfoNode *bb_reg_info_node, bool is_delayed) {
return LiftIntoBlock(inst, block, NthArgument(block->getParent(), kStatePointerArgNum),
bb_reg_info_node, debug_insn_addr, is_delayed);
bb_reg_info_node, is_delayed);
}

llvm::Type *get_llvm_type(llvm::LLVMContext &context, EcvRegClass ecv_reg_class) {
Expand Down Expand Up @@ -338,8 +346,7 @@ llvm::Type *get_llvm_type(llvm::LLVMContext &context, EcvRegClass ecv_reg_class)
// Lift a single instruction into a basic block.
LiftStatus InstructionLifter::LiftIntoBlock(Instruction &arch_inst, llvm::BasicBlock *block,
llvm::Value *state_ptr, BBRegInfoNode *bb_reg_info_node,

uint64_t debug_insn_addr, bool is_delayed) {
bool is_delayed) {
llvm::Function *const func = block->getParent();
llvm::Module *const module = func->getParent();
auto &context = func->getContext();
Expand Down Expand Up @@ -565,6 +572,8 @@ LiftStatus InstructionLifter::LiftIntoBlock(Instruction &arch_inst, llvm::BasicB
<< "Unexpected to multiple lift the call instruction.";
bb_reg_info_node->sema_call_written_reg_map.insert({sema_inst, write_regs});

bb_reg_info_node->sema_func_pc_map.insert({sema_inst, arch_inst.pc});

// Update pre-post index for the target register.
if (!arch_inst.updated_addr_reg.name.empty()) {
const auto [update_reg_ptr_reg, _] =
Expand Down Expand Up @@ -610,17 +619,15 @@ LiftStatus InstructionLifter::LiftIntoBlock(Instruction &arch_inst, llvm::BasicB
// ir.CreateStore(ir.CreateCall(impl->intrinsics->delay_slot_end, temp_args), mem_ptr_ref);
}

/* append debug_insn function call */
if (UINT64_MAX != debug_insn_addr) {
llvm::IRBuilder<> __debug_ir(block);
/* append `debug_memory_value_change` function call */
#if defined(LIFT_MEMORY_VALUE_CHANGE)
auto _debug_memory_value_change_fn = module->getFunction(debug_memory_value_change_name);
auto [runtime_manager_ptr, _] = LoadRegAddress(block, state_ptr, kRuntimeVariableName);
__debug_ir.CreateCall(_debug_memory_value_change_fn,
{__debug_ir.CreateLoad(llvm::Type::getInt64PtrTy(module->getContext()),
runtime_manager_ptr)});
llvm::IRBuilder<> __debug_ir(block);
auto _debug_memory_value_change_fn = module->getFunction(debug_memory_value_change_name);
auto [runtime_manager_ptr, _] = LoadRegAddress(block, state_ptr, kRuntimeVariableName);
__debug_ir.CreateCall(_debug_memory_value_change_fn,
{__debug_ir.CreateLoad(llvm::Type::getInt64PtrTy(module->getContext()),
runtime_manager_ptr)});
#endif
}

return status;
}
Expand Down
3 changes: 1 addition & 2 deletions backend/remill/lib/BC/SleighLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,7 @@ SleighLifterWithState::SleighLifterWithState(sleigh::MaybeBranchTakenVar btaken_
// this instruction will execute within the delay slot of another instruction.
LiftStatus SleighLifterWithState::LiftIntoBlock(Instruction &inst, llvm::BasicBlock *block,
llvm::Value *state_ptr,
BBRegInfoNode *bb_reg_info_node,
uint64_t __debug_insn_addr, bool is_delayed) {
BBRegInfoNode *bb_reg_info_node, bool is_delayed) {
return this->lifter->LiftIntoBlockWithSleighState(inst, block, state_ptr, is_delayed,
this->btaken, this->context_values);
}
Expand Down
Loading
Loading