Skip to content

Commit

Permalink
performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
huettenhain committed Nov 24, 2024
1 parent b285348 commit 3e77b6d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
2 changes: 2 additions & 0 deletions refinery/lib/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LogLevel(IntEnum):
an exception.
"""
NONE = logging.CRITICAL + 50
PROFILE = logging.CRITICAL + 10

@classmethod
def FromVerbosity(cls, verbosity: int):
Expand Down Expand Up @@ -73,6 +74,7 @@ class RefineryFormatter(logging.Formatter):
logging.WARNING : 'warning',
logging.INFO : 'comment',
logging.DEBUG : 'verbose',
LogLevel.PROFILE : 'profile',
}

def __init__(self, format, **kwargs):
Expand Down
14 changes: 6 additions & 8 deletions refinery/units/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,15 +2121,14 @@ def run(cls: Union[Type[Unit], Executable], argv=None, stream=None) -> None:
from time import process_time
argv.remove(cls._SECRET_DEBUG_TIMING_FLAG)
clock = process_time()
cls.logger.setLevel(LogLevel.INFO)
cls.logger.info('starting clock: {:.4f}'.format(clock))
cls.logger.log(LogLevel.PROFILE, 'starting clock: {:.4f}'.format(clock))

if cls._SECRET_YAPPI_TIMING_FLAG in argv:
argv.remove(cls._SECRET_YAPPI_TIMING_FLAG)
try:
import yappi as _yappi
except ImportError:
cls.logger.warn('unable to start yappi; package is missing')
cls.logger.log(LogLevel.PROFILE, 'unable to start yappi; package is missing')
else:
yappi = _yappi

Expand Down Expand Up @@ -2159,8 +2158,7 @@ def run(cls: Union[Type[Unit], Executable], argv=None, stream=None) -> None:
unit.log_level = loglevel

if clock:
unit.log_level = min(unit.log_level, LogLevel.INFO)
unit.logger.info('unit launching: {:.4f}'.format(clock))
cls.logger.log(LogLevel.PROFILE, 'unit launching: {:.4f}'.format(clock))

if yappi is not None:
yappi.set_clock_type('cpu')
Expand All @@ -2184,12 +2182,12 @@ def run(cls: Union[Type[Unit], Executable], argv=None, stream=None) -> None:
stats = yappi.get_func_stats()
filename = F'{unit.name}.perf'
stats.save(filename, type='CALLGRIND')
cls.logger.info(F'wrote yappi results to file: {filename}')
cls.logger.log(LogLevel.PROFILE, F'wrote yappi results to file: {filename}')

if clock:
stop_clock = process_time()
unit.logger.info('stopping clock: {:.4f}'.format(stop_clock))
unit.logger.info('time delta was: {:.4f}'.format(stop_clock - clock))
cls.logger.log(LogLevel.PROFILE, 'stopping clock: {:.4f}'.format(stop_clock))
cls.logger.log(LogLevel.PROFILE, 'time delta was: {:.4f}'.format(stop_clock - clock))


__pdoc__ = {
Expand Down
66 changes: 49 additions & 17 deletions refinery/units/formats/exe/vstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,24 @@ def _get_reg_size(mu: Uc, reg: int):
return q


@dataclass
class EmuConfig:
wait_calls: bool
skip_calls: bool
write_range: slice
wait: int
block_size: int
stack_size: int
log_stack_cookies: bool
log_writes_in_calls: bool
log_stack_addresses: bool
log_other_addresses: bool
log_zero_overwrites: bool


@dataclass
class EmuState:
cfg: EmuConfig
executable: Executable
writes: IntervalTree
expected_address: int
Expand Down Expand Up @@ -357,8 +373,24 @@ def get_register_id(var: str):
self.log_info(F'error mapping segment [{vmem.lower:0{width}X}-{vmem.upper:0{width}X}]: {error!s}')

tree = self._intervaltree.IntervalTree()
args = self.args

cfg = EmuConfig(
args.wait_calls,
args.skip_calls,
args.write_range,
args.wait,
args.block_size,
args.stack_size,
args.log_stack_cookies,
args.log_writes_in_calls,
args.log_stack_addresses,
args.log_other_addresses,
args.log_zero_overwrites,
)

state = EmuState(
exe, tree, address, stack, blob, disassembler,
cfg, exe, tree, address, stack, blob, disassembler,
stop=self.args.stop,
sp_register=sp,
ip_register=ip,
Expand Down Expand Up @@ -386,6 +418,7 @@ def get_register_id(var: str):
except uc.UcError:
pass

tree.merge_overlaps()
it: Iterator[Interval] = iter(tree)
for interval in it:
size = interval.end - interval.begin - 1
Expand Down Expand Up @@ -413,7 +446,7 @@ def _hook_mem_write(self, emu: Uc, access: int, address: int, size: int, value:
if unsigned_value == state.expected_address:
callstack = state.callstack
state.retaddr = unsigned_value
if not self.args.skip_calls:
if not state.cfg.skip_calls:
if not callstack:
state.callstack_ceiling = emu.reg_read(state.sp_register)
callstack.append(unsigned_value)
Expand All @@ -424,21 +457,21 @@ def _hook_mem_write(self, emu: Uc, access: int, address: int, size: int, value:
skipped = False

if (
not self.args.log_stack_cookies
not state.cfg.log_stack_cookies
and emu.reg_read(state.sp_register) ^ unsigned_value == state.last_read
):
skipped = 'stack cookie'
elif size not in bounds[self.args.write_range]:
elif size not in bounds[state.cfg.write_range]:
skipped = 'size excluded'
elif (
state.callstack_ceiling > 0
and not self.args.log_writes_in_calls
and not state.cfg.log_writes_in_calls
and address in range(state.callstack_ceiling - 0x200, state.callstack_ceiling)
):
skipped = 'inside call'
elif not self.args.log_stack_addresses and unsigned_value in state.stack:
elif not state.cfg.log_stack_addresses and unsigned_value in state.stack:
skipped = 'stack address'
elif not self.args.log_other_addresses and not state.blob:
elif not state.cfg.log_other_addresses and not state.blob:
for s in state.executable.sections():
if address in s.virtual:
skipped = F'write to section {s.name}'
Expand All @@ -448,7 +481,7 @@ def _hook_mem_write(self, emu: Uc, access: int, address: int, size: int, value:
not skipped
and unsigned_value == 0
and state.writes.at(address) is not None
and self.args.log_zero_overwrites is False
and state.cfg.log_zero_overwrites is False
):
try:
if any(emu.mem_read(address, size)):
Expand All @@ -458,7 +491,6 @@ def _hook_mem_write(self, emu: Uc, access: int, address: int, size: int, value:

if not skipped:
state.writes.addi(address, address + size + 1)
state.writes.merge_overlaps()
state.waiting = 0

def info():
Expand All @@ -484,7 +516,7 @@ def _hook_insn_error(self, emu: Uc, state: EmuState):
return False

def _hook_mem_error(self, emu: Uc, access: int, address: int, size: int, value: int, state: EmuState):
bs = self.args.block_size
bs = state.cfg.block_size
try:
emu.mem_map(align(bs, address, down=True), 2 * bs)
except Exception:
Expand Down Expand Up @@ -514,10 +546,10 @@ def _hook_code(self, emu: Uc, address: int, size: int, state: EmuState):
state.retaddr = None

if address != state.expected_address:
if retaddr is not None and self.args.skip_calls:
if self.args.skip_calls > 1:
stack_size = self.args.stack_size
block_size = self.args.block_size
if retaddr is not None and state.cfg.skip_calls:
if state.cfg.skip_calls > 1:
stack_size = state.cfg.stack_size
block_size = state.cfg.block_size
rv = state.rv_register
alloc_addr = align(block_size, state.allocations[-1].upper)
state.allocations.append(Range(alloc_addr, alloc_addr + stack_size))
Expand All @@ -535,17 +567,17 @@ def _hook_code(self, emu: Uc, address: int, size: int, state: EmuState):
if depth == 0:
state.callstack_ceiling = 0
state.expected_address = address
elif retaddr is not None and not self.args.skip_calls:
elif retaddr is not None and not state.cfg.skip_calls:
# The present address was moved to the stack but we did not branch.
# This is not quite accurate, of course: We could be calling the
# next instruction. However, that sort of code is usually not really
# a function call anyway, but rather a way to get the IP.
callstack.pop()

if waiting > self.args.wait:
if waiting > state.cfg.wait:
emu.emu_stop()
return False
if not depth or not self.args.wait_calls:
if not depth or not state.cfg.wait_calls:
state.waiting += 1
state.expected_address += size

Expand Down

0 comments on commit 3e77b6d

Please sign in to comment.