Skip to content

Commit

Permalink
revamp gadget.transit_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle-Kyle committed Feb 5, 2025
1 parent 8c2e26d commit 951103e
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 133 deletions.
242 changes: 117 additions & 125 deletions angrop/gadget_finder/gadget_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _analyze_gadget(self, addr, allow_conditional_branches):

# Step 3: gadget effect analysis
l.debug("... analyzing rop potential of block")
gadget = self._create_gadget(addr, init_state, final_state, ctrl_type)
gadget = self._create_gadget(addr, init_state, final_state, ctrl_type, allow_conditional_branches)
if not gadget:
continue

Expand Down Expand Up @@ -286,47 +286,23 @@ def _try_stepping_past_syscall(self, state):
except Exception: # pylint:disable=broad-exception-caught
return state

def _identify_transit_type(self, final_state, ctrl_type):
# FIXME: not always jump, could be call as well
if ctrl_type == 'register':
return "jmp_reg"
if ctrl_type == 'syscall':
return ctrl_type

if ctrl_type == 'pivot':
# FIXME: this logic feels wrong
variables = list(final_state.ip.variables)
if all(x.startswith("sreg_") for x in variables):
@staticmethod
def _control_to_transit_type(ctrl_type):
match ctrl_type:
case 'syscall':
return None
case 'pivot':
return None
case 'register':
return "jmp_reg"
for act in final_state.history.actions:
if act.type != 'mem':
continue
if act.size != self.project.arch.bits:
continue
if (act.data.ast == final_state.ip).symbolic or \
not final_state.solver.eval(act.data.ast == final_state.ip):
continue
sols = final_state.solver.eval_upto(final_state.regs.sp-act.addr.ast, 2)
if len(sols) != 1:
continue
if sols[0] != final_state.arch.bytes:
continue
return "ret"
return "pop_pc"

assert ctrl_type == 'stack'

v = final_state.memory.load(final_state.regs.sp - final_state.arch.bytes,
size=final_state.arch.bytes,
endness=final_state.arch.memory_endness)
if v is final_state.ip:
return "ret"

return "pop_pc"

def _create_gadget(self, addr, init_state, final_state, ctrl_type):
transit_type = self._identify_transit_type(final_state, ctrl_type)

case 'stack':
return 'pop_pc'
case 'memory':
return "jmp_mem"
case _:
raise ValueError("Unknown control type")

def _create_gadget(self, addr, init_state, final_state, ctrl_type, do_cond_branch):
# create the gadget
if ctrl_type == 'syscall' or self._does_syscall(final_state):
# gadgets that do syscall and pivoting are too complicated
Expand All @@ -342,44 +318,46 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type):

# FIXME this doesnt handle multiple steps
gadget.block_length = self.project.factory.block(addr).size
gadget.transit_type = transit_type

# for jmp_reg gadget, record the jump target register
if transit_type == "jmp_reg":
gadget.pc_reg = list(final_state.ip.variables)[0].split('_', 1)[1].rsplit('-')[0]

# compute sp change
l.debug("... computing sp change")
self._compute_sp_change(init_state, final_state, gadget)
if gadget.stack_change % (self.project.arch.bytes) != 0:
if (gadget.stack_change % self.project.arch.bytes) != 0:
l.debug("... uneven sp change")
return None
if gadget.stack_change < 0:
l.debug("stack change is negative!!")
#FIXME: technically, it can be negative, e.g. call instructions
return None

# record pc_offset
if type(gadget) is not PivotGadget and transit_type in ['pop_pc', 'ret']:
idx = list(final_state.ip.variables)[0].split('_')[2]
gadget.pc_offset = int(idx) * self.project.arch.bytes
if gadget.pc_offset >= gadget.stack_change:
return None
# transit_type-based handling
transit_type = self._control_to_transit_type(ctrl_type)
gadget.transit_type = transit_type
match transit_type:
case 'pop_pc': # record pc_offset
idx = list(final_state.ip.variables)[0].split('_')[2]
gadget.pc_offset = int(idx) * self.project.arch.bytes
if gadget.pc_offset >= gadget.stack_change:
return None
case 'jmp_reg': # record pc_reg
gadget.pc_reg = list(final_state.ip.variables)[0].split('_', 1)[1].rsplit('-')[0]
case 'jmp_mem': # record pc_target
for a in reversed(final_state.history.actions):
if a.type == 'mem' and a.action == 'read':
if (a.data.ast == final_state.ip).is_true():
gadget.pc_target = a.addr.ast
break

# register effect analysis
l.info("... checking for controlled regs")
self._check_reg_changes(final_state, init_state, gadget)

# check for reg moves
# get reg reads
reg_reads = self._get_reg_reads(final_state)
l.debug("... checking for reg moves")
self._check_reg_change_dependencies(init_state, final_state, gadget)
self._check_reg_movers(init_state, final_state, reg_reads, gadget)

# check concretized registers
self._analyze_concrete_regs(init_state, final_state, gadget)

# check mem accesses
# memory access analysis
l.debug("... analyzing mem accesses")
if not self._analyze_mem_access(final_state, init_state, gadget):
l.debug("... too many symbolic memory accesses")
Expand All @@ -390,39 +368,40 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type):
l.debug("... mem access with no addr dependencies")
return None

# Store block address list for gadgets with conditional branches
gadget.bbl_addrs = list(final_state.history.bbl_addrs)
gadget.isn_count = sum(self.project.factory.block(addr).instructions for addr in gadget.bbl_addrs)

constraint_vars = {
var
for constraint in final_state.history.jump_guards
for var in constraint.variables
}

gadget.has_conditional_branch = len(constraint_vars) > 0

for action in final_state.history.actions:
if action.type == 'mem':
constraint_vars |= action.addr.variables

for var in constraint_vars:
if var.startswith("sreg_"):
gadget.constraint_regs.add(var.split('_', 1)[1].split('-', 1)[0])
elif not var.startswith("symbolic_stack_"):
l.debug("... constraint not controlled by registers and stack")
return None
# conditional branch analysis
if do_cond_branch:
constraint_vars = {
var
for constraint in final_state.history.jump_guards
for var in constraint.variables
}

gadget.has_conditional_branch = len(constraint_vars) > 0

for action in final_state.history.actions:
if action.type == 'mem':
constraint_vars |= action.addr.variables

for var in constraint_vars:
if var.startswith("sreg_"):
gadget.constraint_regs.add(var.split('_', 1)[1].split('-', 1)[0])
elif not var.startswith("symbolic_stack_"):
l.debug("... constraint not controlled by registers and stack")
return None

gadget.popped_regs = {
reg
for reg in gadget.popped_regs
if final_state.registers.load(reg).variables.isdisjoint(constraint_vars)
}
gadget.popped_regs = {
reg
for reg in gadget.popped_regs
if final_state.registers.load(reg).variables.isdisjoint(constraint_vars)
}

gadget.popped_reg_vars = {
reg: final_state.registers.load(reg).variables
for reg in gadget.popped_regs
}
gadget.popped_reg_vars = {
reg: final_state.registers.load(reg).variables
for reg in gadget.popped_regs
}

return gadget

Expand Down Expand Up @@ -520,62 +499,36 @@ def _check_reg_movers(self, symbolic_state, symbolic_p, reg_reads, gadget):
if ast_1 is ast_2:
gadget.reg_moves.append(RopRegMove(from_reg, reg, half_bits))

# TODO: need to handle reg calls
def _check_for_control_type(self, init_state, final_state):
"""
:return: the data provenance of the controlled ip in the final state, either the stack or registers
:return: the data provenance of the controlled ip in the final state
"""

ip = final_state.ip

# this gadget arrives a syscall
# this gadget arrives at a syscall
if self.is_in_kernel(final_state):
return 'syscall'

# the ip is controlled by stack
# the ip is controlled by stack (ret)
if self._check_if_stack_controls_ast(ip, init_state):
return "stack"

# the ip is not controlled by regs
# the ip is not controlled by regs/mem
if not ip.variables:
return None
ip_variables = list(ip.variables)

# the ip is fully controlled by regs
variables = list(ip.variables)
if all(x.startswith("sreg_") for x in variables):
# the ip is fully controlled by regs (jmp rax)
if all(x.startswith("sreg_") for x in ip_variables):
return "register"

# this is a stack pivoting gadget
if all(x.startswith("symbolic_read_") for x in variables) and len(final_state.regs.sp.variables) == 1:
# we don't fully control sp
if not init_state.solver.satisfiable(extra_constraints=[final_state.regs.sp == 0x41414100]):
return None
# make sure the control after pivot is reasonable

# find where the ip is read from
saved_ip_addr = None
for act in final_state.history.actions:
if act.type == 'mem' and act.action == 'read':
if (
act.size == self.project.arch.bits
and isinstance(act.data.ast, claripy.ast.BV)
and not (act.data.ast == ip).symbolic
):
if init_state.solver.eval(act.data.ast == ip):
saved_ip_addr = act.addr.ast
break
if saved_ip_addr is None:
return None
# the ip is fully controlled by memory and sp is not symbolic (jmp [rax])
if all(x.startswith("symbolic_read_") for x in ip_variables) and not final_state.regs.sp.symbolic:
return "memory"

# if the saved ip is too far away from the final sp, that's a bad gadget
sols = final_state.solver.eval_upto(final_state.regs.sp - saved_ip_addr, 2)
if len(sols) != 1: # the saved ip has a symbolic distance from the final sp, bad
return None
offset = sols[0]
if offset > self._stack_bsize: # filter out gadgets like mov rsp, rax; ret 0x1000
return None
if offset % self.project.arch.bytes != 0: # filter misaligned gadgets
return None
# this is a stack pivoting gadget
if self._check_if_stack_pivot(init_state, final_state):
return "pivot"

return None
Expand Down Expand Up @@ -628,6 +581,45 @@ def _check_if_stack_controls_ast(self, ast, initial_state, gadget_stack_change=N

return ans

def _check_if_stack_pivot(self, init_state, final_state):
    """
    Decide whether the gadget ending in `final_state` is a stack pivot.

    :param init_state: state whose solver is used for the satisfiability and evaluation checks
    :param final_state: state after executing the gadget; its ip, sp and action history are inspected
    :return: "pivot" if every check below passes, otherwise None
    """
    # NOTE: `ip` must be bound here. This logic used to live inside
    # _check_for_control_type where `ip` was a local; without this binding
    # the comparisons against `ip` below raise NameError.
    ip = final_state.ip

    # every variable the ip depends on must come from a symbolic memory read
    ip_variables = list(ip.variables)
    if any(not x.startswith("symbolic_read_") for x in ip_variables):
        return None
    # sp must depend on exactly one variable
    if len(final_state.regs.sp.variables) != 1:
        return None

    # check if we fully control sp
    if not init_state.solver.satisfiable(extra_constraints=[final_state.regs.sp == 0x41414100]):
        return None

    # make sure the control after pivot is reasonable

    # find where the ip is read from: the first arch-sized memory read whose
    # data compares equal to the final ip with a non-symbolic result
    saved_ip_addr = None
    for act in final_state.history.actions:
        if act.type != 'mem' or act.action != 'read':
            continue
        if (
            act.size == self.project.arch.bits
            and isinstance(act.data.ast, claripy.ast.BV)
            and not (act.data.ast == ip).symbolic
        ):
            if init_state.solver.eval(act.data.ast == ip):
                saved_ip_addr = act.addr.ast
                break
    if saved_ip_addr is None:
        return None

    # if the saved ip is too far away from the final sp, that's a bad gadget
    sols = final_state.solver.eval_upto(final_state.regs.sp - saved_ip_addr, 2)
    if len(sols) != 1: # the saved ip has a symbolic distance from the final sp, bad
        return None
    offset = sols[0]
    if offset > self._stack_bsize: # filter out gadgets like mov rsp, rax; ret 0x1000
        return None
    if offset % self.project.arch.bytes != 0: # filter misaligned gadgets
        return None
    return "pivot"

def _to_signed(self, value):
bits = self.project.arch.bits
if value >> (bits-1): # if the MSB is 1, this value is negative
Expand Down
18 changes: 10 additions & 8 deletions angrop/rop_gadget.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,16 @@ def __init__(self, addr):
self.mem_writes = []
self.mem_changes = []

# TODO: pc shouldn't be treated differently from other registers
# it is just a register. With the register setting framework, we will be able to
# utilize gadgets like `call qword ptr [rax+rbx]` because we have the dependency information.
# transition information, i.e. how to pass the control flow to the next gadget
# gadget transition
# we now support the following gadget transitions
# 1. pop_pc: (ret, jmp [sp+X], pop pc,X,Y, retn), this type of gadget is "self-contained"
# 2. jmp_reg: jmp reg <- requires reg setting before using it (call falls here as well)
# 3. jmp_mem: jmp [reg+X] <- requires mem setting before using it (call is here as well)
self.transit_type: str = None # type: ignore
self.pc_reg = None
# pc_offset is exclusively used when transit_type is "pop_pc",
# when pc_offset==stack_change-arch_bytes, transit_type is basically ret
self.pc_offset = None

self.pc_offset = None # for pop_pc, ret is basically pc_offset == stack_change - arch.bytes
self.pc_reg = None # for jmp_reg, which register it jumps to
self.pc_target = None # for jmp_mem, where it jumps to

# List of basic block addresses for gadgets with conditional branches
self.bbl_addrs = []
Expand Down Expand Up @@ -274,6 +275,7 @@ class SyscallGadget(RopGadget):
def __init__(self, addr):
super().__init__(addr)
self.makes_syscall = False
# TODO: starts_with_syscall should be removed, not useful at all
self.starts_with_syscall = False

def __str__(self):
Expand Down

0 comments on commit 951103e

Please sign in to comment.