Skip to content

Commit

Permalink
Merge pull request #65 from mrexodia/exception-hooks
Browse files Browse the repository at this point in the history
Implement exception hooks
  • Loading branch information
mrexodia authored Mar 16, 2023
2 parents edc4ec6 + 3e2509b commit b22e8e3
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 66 deletions.
2 changes: 1 addition & 1 deletion src/dumpulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .dumpulator import Dumpulator
from .dumpulator import Dumpulator, ExceptionType, MemoryViolation, ExceptionInfo
from .ntsyscalls import syscall
167 changes: 105 additions & 62 deletions src/dumpulator/dumpulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import traceback
from enum import Enum
from typing import List, Union, NamedTuple
from typing import List, Union, NamedTuple, Callable
import inspect
from collections import OrderedDict
from dataclasses import dataclass, field
Expand Down Expand Up @@ -39,22 +39,41 @@ class ExceptionType(Enum):
ContextSwitch = 3
Terminate = 4

class MemoryViolation(Enum):
Unknown = 0
ReadUnmapped = 1
WriteUnmapped = 2
ExecuteUnmapped = 3
ReadProtect = 4
WriteProtect = 5
ExecuteProtect = 6
ReadUnaligned = 7
WriteUnaligned = 8
ExecuteUnaligned = 9

@dataclass
class ExceptionInfo:
type: ExceptionType = ExceptionType.NoException
memory_access: int = 0 # refers to UC_MEM_* values
# type == ExceptionType.Memory
memory_violation: MemoryViolation = MemoryViolation.Unknown
memory_address: int = 0
memory_size: int = 0
memory_value: int = 0
# type == ExceptionType.Interrupt
interrupt_number: int = 0

# Internal state
_handling: bool = False

@dataclass
class UnicornExceptionInfo(ExceptionInfo):
final: bool = False
code_hook_h: Optional[int] = None # represents a `unicorn.uc_hook_h` value (from uc.hook_add)
context: Optional[unicorn.UcContext] = None
tb_start: int = 0
tb_size: int = 0
tb_icount: int = 0
step_count: int = 0
final: bool = False
handling: bool = False

def __str__(self):
return f"{self.type}, ({hex(self.tb_start)}, {hex(self.tb_size)}, {self.tb_icount})"
Expand Down Expand Up @@ -305,8 +324,9 @@ def __init__(self, minidump_file, *, trace=False, quiet=False, thread_id=None, d
self.kill_me = None
self.exit_code = None
self.exports = self._all_exports()
self.exception = ExceptionInfo()
self.last_exception: Optional[ExceptionInfo] = None
self._exception = UnicornExceptionInfo()
self._last_exception: Optional[UnicornExceptionInfo] = None
self._exception_hook: Optional[Callable[[ExceptionInfo], Optional[int]]] = None
if not self._quiet:
print("Memory map:")
self.print_memory()
Expand Down Expand Up @@ -896,15 +916,28 @@ def allocate(self, size, page_align=False):
self.memory.commit(self.memory.align_page(ptr), self.memory.align_page(size))
return ptr

def handle_exception(self):
assert not self.exception.handling
self.exception.handling = True
def set_exception_hook(self, exception_hook: Optional[Callable[[ExceptionInfo], Optional[int]]]):
previous_hook = self._exception_hook
self._exception_hook = exception_hook
return previous_hook

if self.exception.type == ExceptionType.ContextSwitch:
def handle_exception(self):
assert not self._exception._handling
self._exception._handling = True

if self._exception_hook is not None:
hook_result = self._exception_hook(self._exception)
if hook_result is not None:
# Clear the pending exception
self._last_exception = self._exception
self._exception = UnicornExceptionInfo()
return hook_result

if self._exception.type == ExceptionType.ContextSwitch:
self.info(f"context switch, cip: {hex(self.regs.cip)}")
# Clear the pending exception
self.last_exception = self.exception
self.exception = ExceptionInfo()
self._last_exception = self._exception
self._exception = UnicornExceptionInfo()
# NOTE: the context has already been restored using context_restore in the caller
return self.regs.cip

Expand Down Expand Up @@ -961,22 +994,23 @@ def handle_exception(self):
context_ex.XState.Offset = 0xF0 if self._x64 else 0x20
context_ex.XState.Length = 0x160 if self._x64 else 0x140
record = record_type()
if self.exception.type == ExceptionType.Memory:
record.ExceptionCode = 0xC0000005
alignment_violations = [MemoryViolation.ReadUnaligned, MemoryViolation.WriteUnaligned, MemoryViolation.ExecuteUnaligned]
if self._exception.type == ExceptionType.Memory and self._exception.memory_violation not in alignment_violations:
record.ExceptionCode = STATUS_ACCESS_VIOLATION
record.ExceptionFlags = 0
record.ExceptionAddress = self.regs.cip
record.NumberParameters = 2
types = {
UC_MEM_READ_UNMAPPED: EXCEPTION_READ_FAULT,
UC_MEM_WRITE_UNMAPPED: EXCEPTION_WRITE_FAULT,
UC_MEM_FETCH_UNMAPPED: EXCEPTION_READ_FAULT,
UC_MEM_READ_PROT: EXCEPTION_READ_FAULT,
UC_MEM_WRITE_PROT: EXCEPTION_WRITE_FAULT,
UC_MEM_FETCH_PROT: EXCEPTION_EXECUTE_FAULT,
MemoryViolation.ReadUnmapped: EXCEPTION_READ_FAULT,
MemoryViolation.WriteUnmapped: EXCEPTION_WRITE_FAULT,
MemoryViolation.ExecuteUnmapped: EXCEPTION_READ_FAULT,
MemoryViolation.ReadProtect: EXCEPTION_READ_FAULT,
MemoryViolation.WriteProtect: EXCEPTION_WRITE_FAULT,
MemoryViolation.ExecuteProtect: EXCEPTION_EXECUTE_FAULT,
}
record.ExceptionInformation[0] = types[self.exception.memory_access]
record.ExceptionInformation[1] = self.exception.memory_address
elif self.exception.type == ExceptionType.Interrupt and self.exception.interrupt_number == 3:
record.ExceptionInformation[0] = types[self._exception.memory_violation]
record.ExceptionInformation[1] = self._exception.memory_address
elif self._exception.type == ExceptionType.Interrupt and self._exception.interrupt_number == 3:
if self._x64:
context.Rip -= 1 # TODO: long int3 and prefixes
record.ExceptionCode = 0x80000003
Expand All @@ -990,11 +1024,11 @@ def handle_exception(self):
record.ExceptionAddress = context.Eip
record.NumberParameters = 1
else:
raise NotImplementedError(f"{self.exception}") # TODO: implement
raise NotImplementedError(f"{self._exception}") # TODO: implement

# Clear the pending exception
self.last_exception = self.exception
self.exception = ExceptionInfo()
self._last_exception = self._exception
self._exception = UnicornExceptionInfo()

def write_stack(cur_ptr: int, data: bytes):
self.write(cur_ptr, data)
Expand Down Expand Up @@ -1024,19 +1058,19 @@ def write_stack(cur_ptr: int, data: bytes):

def start(self, begin, end=0xffffffffffffffff, count=0) -> None:
# Clear exceptions before starting
self.exception = ExceptionInfo()
self._exception = UnicornExceptionInfo()
emu_begin = begin
emu_until = end
emu_count = count
while True:
try:
if self.exception.type != ExceptionType.NoException:
if self.exception.final:
if self._exception.type != ExceptionType.NoException:
if self._exception.final:
# Restore the context (unicorn might mess with it before stopping)
if self.exception.context is not None:
self._uc.context_restore(self.exception.context)
if self._exception.context is not None:
self._uc.context_restore(self._exception.context)

if self.exception.type == ExceptionType.Terminate:
if self._exception.type == ExceptionType.Terminate:
if self.exit_code is not None:
self.info(f"exit code: {hex(self.exit_code)}")
break
Expand All @@ -1051,20 +1085,20 @@ def start(self, begin, end=0xffffffffffffffff, count=0) -> None:
emu_count = 0
else:
# If this happens there was an error restarting simulation
assert self.exception.step_count == 0
assert self._exception.step_count == 0

# Hook should be installed at this point
assert self.exception.code_hook_h is not None
assert self._exception.code_hook_h is not None

# Restore the context (unicorn might mess with it before stopping)
assert self.exception.context is not None
self._uc.context_restore(self.exception.context)
assert self._exception.context is not None
self._uc.context_restore(self._exception.context)

# Restart emulation
self.info(f"restarting emulation to handle exception...")
emu_begin = self.regs.cip
emu_until = 0xffffffffffffffff
emu_count = self.exception.tb_icount + 1
emu_count = self._exception.tb_icount + 1

self.info(f"emu_start({hex(emu_begin)}, {hex(emu_until)}, {emu_count})")
self.kill_me = None
Expand All @@ -1076,7 +1110,7 @@ def start(self, begin, end=0xffffffffffffffff, count=0) -> None:
except UcError as err:
if self.kill_me is not None and type(self.kill_me) is not UcError:
raise self.kill_me
if self.exception.type != ExceptionType.NoException:
if self._exception.type != ExceptionType.NoException:
# Handle the exception outside of the except handler
continue
else:
Expand Down Expand Up @@ -1232,7 +1266,7 @@ def load_dll(self, file_name: str, file_data: bytes):
def _hook_code_exception(uc: Uc, address, size, dp: Dumpulator):
try:
dp.info(f"exception step: {hex(address)}[{size}]")
ex = dp.exception
ex = dp._exception
ex.step_count += 1
if ex.step_count >= ex.tb_icount:
raise Exception("Stepped past the basic block without reaching exception")
Expand All @@ -1246,18 +1280,27 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator):
return True

fetch_accesses = [UC_MEM_FETCH, UC_MEM_FETCH_PROT, UC_MEM_FETCH_UNMAPPED]
if access == UC_MEM_FETCH_UNMAPPED and address >= FORCE_KILL_ADDR - 0x10 and address <= FORCE_KILL_ADDR + 0x10 and dp.kill_me is not None:
if access == UC_MEM_FETCH_UNMAPPED and FORCE_KILL_ADDR - 0x10 <= address <= FORCE_KILL_ADDR + 0x10 and dp.kill_me is not None:
dp.error(f"forced exit memory operation {access} of {hex(address)}[{hex(size)}] = {hex(value)}")
return False
if dp.exception.final and access in fetch_accesses:
if dp._exception.final and access in fetch_accesses:
dp.info(f"fetch from {hex(address)}[{size}] already reported")
return False
# TODO: figure out why when you start executing at 0 this callback is triggered more than once
try:
violation = {
UC_MEM_READ_UNMAPPED: MemoryViolation.ReadUnmapped,
UC_MEM_WRITE_UNMAPPED: MemoryViolation.WriteUnmapped,
UC_MEM_FETCH_UNMAPPED: MemoryViolation.ExecuteUnmapped,
UC_MEM_READ_PROT: MemoryViolation.ReadProtect,
UC_MEM_WRITE_PROT: MemoryViolation.WriteProtect,
UC_MEM_FETCH_PROT: MemoryViolation.ExecuteProtect,
}.get(access, MemoryViolation.Unknown)
assert violation != MemoryViolation.Unknown, f"Unexpected memory access {access}"
# Extract exception information
exception = ExceptionInfo()
exception = UnicornExceptionInfo()
exception.type = ExceptionType.Memory
exception.memory_access = access
exception.memory_violation = violation
exception.memory_address = address
exception.memory_size = size
exception.memory_value = value
Expand All @@ -1269,7 +1312,7 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator):
exception.tb_icount = tb.icount

# Print exception info
final = dp.trace or dp.exception.code_hook_h is not None
final = dp.trace or dp._exception.code_hook_h is not None
info = "final" if final else "initial"
if access == UC_MEM_READ_UNMAPPED:
dp.error(f"{info} unmapped read from {hex(address)}[{hex(size)}], cip = {hex(dp.regs.cip)}, exception: {exception}")
Expand All @@ -1295,25 +1338,25 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator):
if final:
# Make sure this is the same exception we expect
if not dp.trace:
assert access == dp.exception.memory_access
assert address == dp.exception.memory_address
assert size == dp.exception.memory_size
assert value == dp.exception.memory_value
assert violation == dp._exception.memory_violation
assert address == dp._exception.memory_address
assert size == dp._exception.memory_size
assert value == dp._exception.memory_value

# Delete the code hook
uc.hook_del(dp.exception.code_hook_h)
dp.exception.code_hook_h = None
uc.hook_del(dp._exception.code_hook_h)
dp._exception.code_hook_h = None

# At this point we know for sure the context is correct so we can report the exception
dp.exception = exception
dp.exception.final = True
dp._exception = exception
dp._exception.final = True

# Stop emulation (we resume it on KiUserExceptionDispatcher later)
dp.stop()
return False

# There should not be an exception active
assert dp.exception.type == ExceptionType.NoException
assert dp._exception.type == ExceptionType.NoException

# Remove the translation block cache for this block
# Without doing this single stepping the block won't work
Expand All @@ -1325,7 +1368,7 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator):
exception.code_hook_h = uc.hook_add(UC_HOOK_CODE, _hook_code_exception, user_data=dp)

# Store the exception info
dp.exception = exception
dp._exception = exception

# Stop emulation (we resume execution later)
dp.stop()
Expand Down Expand Up @@ -1452,7 +1495,7 @@ def _arg_type_string(arg):
def _hook_interrupt(uc: Uc, number, dp: Dumpulator):
try:
# Extract exception information
exception = ExceptionInfo()
exception = UnicornExceptionInfo()
exception.type = ExceptionType.Interrupt
exception.interrupt_number = number
exception.context = uc.context_save()
Expand All @@ -1470,11 +1513,11 @@ def _hook_interrupt(uc: Uc, number, dp: Dumpulator):
dp.error(f"interrupt {number} ({description}), cip = {hex(dp.regs.cip)}, cs = {hex(dp.regs.cs)}")

# There should not be an exception active
assert dp.exception.type == ExceptionType.NoException
assert dp._exception.type == ExceptionType.NoException

# At this point we know for sure the context is correct so we can report the exception
dp.exception = exception
dp.exception.final = True
dp._exception = exception
dp._exception.final = True
except AssertionError as err:
traceback.print_exc()
raise err
Expand Down Expand Up @@ -1560,7 +1603,7 @@ def syscall_arg(index):
status = syscall_impl(dp, *args)
if isinstance(status, ExceptionInfo):
print("context switch, stopping emulation")
dp.exception = status
dp._exception = status
raise dp.raise_kill(UcError(UC_ERR_EXCEPTION)) from None
else:
dp.info(f"status = {hex(status)}")
Expand Down Expand Up @@ -1610,11 +1653,11 @@ def _hook_invalid(uc: Uc, dp: Dumpulator):
instr = next(dp.cs.disasm(code, address, 1))
if _emulate_unsupported_instruction(dp, instr):
# Resume execution with a context switch
assert dp.exception.type == ExceptionType.NoException
exception = ExceptionInfo()
assert dp._exception.type == ExceptionType.NoException
exception = UnicornExceptionInfo()
exception.type = ExceptionType.ContextSwitch
exception.final = True
dp.exception = exception
dp._exception = exception
return False # NOTE: returning True would stop emulation
except StopIteration:
pass # Unsupported instruction
Expand Down
1 change: 1 addition & 0 deletions src/dumpulator/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# NTSTATUS
STATUS_SUCCESS = 0
STATUS_NOT_IMPLEMENTED = 0xC0000002
STATUS_ACCESS_VIOLATION = 0xC0000005
STATUS_INVALID_HANDLE = 0xC0000008
STATUS_NO_SUCH_FILE = 0xC000000F
STATUS_ACCESS_DENIED = 0xC0000022
Expand Down
Loading

0 comments on commit b22e8e3

Please sign in to comment.