diff --git a/angrop/chain_builder/func_caller.py b/angrop/chain_builder/func_caller.py index aec780c..b2096d3 100644 --- a/angrop/chain_builder/func_caller.py +++ b/angrop/chain_builder/func_caller.py @@ -187,10 +187,14 @@ def _func_call(self, func_gadget, cc, args, extra_regs=None, preserve_regs=None, # 2. handle function return address to maintain the control flow if stack_arguments: shift_bytes = (len(stack_arguments)+1)*arch_bytes + # TODO: currently, we only shift stack only for the minimal + # but if this shift fails, we should try larger shifts cleaner = self.chain_builder.shift(shift_bytes, next_pc_idx=-1, preserve_regs=preserve_regs) chain.add_gadget(cleaner._gadgets[0]) for arg in stack_arguments: chain.add_value(arg) + next_pc = claripy.BVS("next_pc", self.project.arch.bits) + chain.add_value(next_pc) # handle return address if not isinstance(cc.RETURN_ADDR, (SimStackArg, SimRegArg)): diff --git a/angrop/chain_builder/shifter.py b/angrop/chain_builder/shifter.py index d71d1ff..b5e809a 100644 --- a/angrop/chain_builder/shifter.py +++ b/angrop/chain_builder/shifter.py @@ -125,6 +125,8 @@ def _same_effect(self, g1, g2): return False if g1.transit_type != g2.transit_type: return False + if g1.pc_offset != g2.pc_offset: + return False return True def _better_than(self, g1, g2): diff --git a/angrop/rop_chain.py b/angrop/rop_chain.py index dc094d3..aec12ec 100644 --- a/angrop/rop_chain.py +++ b/angrop/rop_chain.py @@ -1,5 +1,7 @@ import logging +import claripy + from . import rop_utils from .errors import RopException from .rop_gadget import RopGadget @@ -46,11 +48,14 @@ def __add__(self, other): # add the other values and gadgets result._gadgets.extend(other._gadgets) idx = self.next_pc_idx() - assert idx is not None, "can't add to a chain that does not return!" + assert idx is not None or not self._values, "can't add to a chain that does not return!" result._payload_len = self._payload_len + other._payload_len - result._values[idx] = other._values[0] - result._values.extend(other._values[1:]) - result._payload_len -= self._p.arch.bytes + if idx is not None: + result._values[idx] = other._values[0] + result._values.extend(other._values[1:]) + result._payload_len -= self._p.arch.bytes + else: + result._values.extend(other._values) return result def set_timeout(self, timeout):