diff --git a/src/dumpulator/__init__.py b/src/dumpulator/__init__.py index 65428e4..e65118a 100644 --- a/src/dumpulator/__init__.py +++ b/src/dumpulator/__init__.py @@ -1,2 +1,2 @@ -from .dumpulator import Dumpulator +from .dumpulator import Dumpulator, ExceptionType, MemoryViolation, ExceptionInfo from .ntsyscalls import syscall diff --git a/src/dumpulator/dumpulator.py b/src/dumpulator/dumpulator.py index 8e0c2a2..8871c4b 100644 --- a/src/dumpulator/dumpulator.py +++ b/src/dumpulator/dumpulator.py @@ -3,7 +3,7 @@ import sys import traceback from enum import Enum -from typing import List, Union, NamedTuple +from typing import List, Union, NamedTuple, Callable import inspect from collections import OrderedDict from dataclasses import dataclass, field @@ -39,22 +39,41 @@ class ExceptionType(Enum): ContextSwitch = 3 Terminate = 4 +class MemoryViolation(Enum): + Unknown = 0 + ReadUnmapped = 1 + WriteUnmapped = 2 + ExecuteUnmapped = 3 + ReadProtect = 4 + WriteProtect = 5 + ExecuteProtect = 6 + ReadUnaligned = 7 + WriteUnaligned = 8 + ExecuteUnaligned = 9 + @dataclass class ExceptionInfo: type: ExceptionType = ExceptionType.NoException - memory_access: int = 0 # refers to UC_MEM_* values + # type == ExceptionType.Memory + memory_violation: MemoryViolation = MemoryViolation.Unknown memory_address: int = 0 memory_size: int = 0 memory_value: int = 0 + # type == ExceptionType.Interrupt interrupt_number: int = 0 + + # Internal state + _handling: bool = False + +@dataclass +class UnicornExceptionInfo(ExceptionInfo): + final: bool = False code_hook_h: Optional[int] = None # represents a `unicorn.uc_hook_h` value (from uc.hook_add) context: Optional[unicorn.UcContext] = None tb_start: int = 0 tb_size: int = 0 tb_icount: int = 0 step_count: int = 0 - final: bool = False - handling: bool = False def __str__(self): return f"{self.type}, ({hex(self.tb_start)}, {hex(self.tb_size)}, {self.tb_icount})" @@ -305,8 +324,9 @@ def __init__(self, minidump_file, *, trace=False, quiet=False, thread_id=None, d self.kill_me = None self.exit_code = None self.exports = self._all_exports() - self.exception = ExceptionInfo() - self.last_exception: Optional[ExceptionInfo] = None + self._exception = UnicornExceptionInfo() + self._last_exception: Optional[UnicornExceptionInfo] = None + self._exception_hook: Optional[Callable[[ExceptionInfo], Optional[int]]] = None if not self._quiet: print("Memory map:") self.print_memory() @@ -896,15 +916,28 @@ def allocate(self, size, page_align=False): self.memory.commit(self.memory.align_page(ptr), self.memory.align_page(size)) return ptr - def handle_exception(self): - assert not self.exception.handling - self.exception.handling = True + def set_exception_hook(self, exception_hook: Optional[Callable[[ExceptionInfo], Optional[int]]]): + previous_hook = self._exception_hook + self._exception_hook = exception_hook + return previous_hook - if self.exception.type == ExceptionType.ContextSwitch: + def handle_exception(self): + assert not self._exception._handling + self._exception._handling = True + + if self._exception_hook is not None: + hook_result = self._exception_hook(self._exception) + if hook_result is not None: + # Clear the pending exception + self._last_exception = self._exception + self._exception = UnicornExceptionInfo() + return hook_result + + if self._exception.type == ExceptionType.ContextSwitch: self.info(f"context switch, cip: {hex(self.regs.cip)}") # Clear the pending exception - self.last_exception = self.exception - self.exception = ExceptionInfo() + self._last_exception = self._exception + self._exception = UnicornExceptionInfo() # NOTE: the context has already been restored using context_restore in the caller return self.regs.cip @@ -961,22 +994,23 @@ def handle_exception(self): context_ex.XState.Offset = 0xF0 if self._x64 else 0x20 context_ex.XState.Length = 0x160 if self._x64 else 0x140 record = record_type() - if self.exception.type == ExceptionType.Memory: - record.ExceptionCode = 0xC0000005 + alignment_violations = [MemoryViolation.ReadUnaligned, MemoryViolation.WriteUnaligned, MemoryViolation.ExecuteUnaligned] + if self._exception.type == ExceptionType.Memory and self._exception.memory_violation not in alignment_violations: + record.ExceptionCode = STATUS_ACCESS_VIOLATION record.ExceptionFlags = 0 record.ExceptionAddress = self.regs.cip record.NumberParameters = 2 types = { - UC_MEM_READ_UNMAPPED: EXCEPTION_READ_FAULT, - UC_MEM_WRITE_UNMAPPED: EXCEPTION_WRITE_FAULT, - UC_MEM_FETCH_UNMAPPED: EXCEPTION_READ_FAULT, - UC_MEM_READ_PROT: EXCEPTION_READ_FAULT, - UC_MEM_WRITE_PROT: EXCEPTION_WRITE_FAULT, - UC_MEM_FETCH_PROT: EXCEPTION_EXECUTE_FAULT, + MemoryViolation.ReadUnmapped: EXCEPTION_READ_FAULT, + MemoryViolation.WriteUnmapped: EXCEPTION_WRITE_FAULT, + MemoryViolation.ExecuteUnmapped: EXCEPTION_READ_FAULT, + MemoryViolation.ReadProtect: EXCEPTION_READ_FAULT, + MemoryViolation.WriteProtect: EXCEPTION_WRITE_FAULT, + MemoryViolation.ExecuteProtect: EXCEPTION_EXECUTE_FAULT, } - record.ExceptionInformation[0] = types[self.exception.memory_access] - record.ExceptionInformation[1] = self.exception.memory_address - elif self.exception.type == ExceptionType.Interrupt and self.exception.interrupt_number == 3: + record.ExceptionInformation[0] = types[self._exception.memory_violation] + record.ExceptionInformation[1] = self._exception.memory_address + elif self._exception.type == ExceptionType.Interrupt and self._exception.interrupt_number == 3: if self._x64: context.Rip -= 1 # TODO: long int3 and prefixes record.ExceptionCode = 0x80000003 @@ -990,11 +1024,11 @@ def handle_exception(self): record.ExceptionAddress = context.Eip record.NumberParameters = 1 else: - raise NotImplementedError(f"{self.exception}") # TODO: implement + raise NotImplementedError(f"{self._exception}") # TODO: implement # Clear the pending exception - self.last_exception = self.exception - self.exception = ExceptionInfo() + self._last_exception = self._exception + self._exception = UnicornExceptionInfo() def write_stack(cur_ptr: int, data: bytes): self.write(cur_ptr, data) @@ -1024,19 +1058,19 @@ def write_stack(cur_ptr: int, data: bytes): def start(self, begin, end=0xffffffffffffffff, count=0) -> None: # Clear exceptions before starting - self.exception = ExceptionInfo() + self._exception = UnicornExceptionInfo() emu_begin = begin emu_until = end emu_count = count while True: try: - if self.exception.type != ExceptionType.NoException: - if self.exception.final: + if self._exception.type != ExceptionType.NoException: + if self._exception.final: # Restore the context (unicorn might mess with it before stopping) - if self.exception.context is not None: - self._uc.context_restore(self.exception.context) + if self._exception.context is not None: + self._uc.context_restore(self._exception.context) - if self.exception.type == ExceptionType.Terminate: + if self._exception.type == ExceptionType.Terminate: if self.exit_code is not None: self.info(f"exit code: {hex(self.exit_code)}") break @@ -1051,20 +1085,20 @@ def start(self, begin, end=0xffffffffffffffff, count=0) -> None: emu_count = 0 else: # If this happens there was an error restarting simulation - assert self.exception.step_count == 0 + assert self._exception.step_count == 0 # Hook should be installed at this point - assert self.exception.code_hook_h is not None + assert self._exception.code_hook_h is not None # Restore the context (unicorn might mess with it before stopping) - assert self.exception.context is not None - self._uc.context_restore(self.exception.context) + assert self._exception.context is not None + self._uc.context_restore(self._exception.context) # Restart emulation self.info(f"restarting emulation to handle exception...") emu_begin = self.regs.cip emu_until = 0xffffffffffffffff - emu_count = self.exception.tb_icount + 1 + emu_count = self._exception.tb_icount + 1 self.info(f"emu_start({hex(emu_begin)}, {hex(emu_until)}, {emu_count})") self.kill_me = None @@ -1076,7 +1110,7 @@ def start(self, begin, end=0xffffffffffffffff, count=0) -> None: except UcError as err: if self.kill_me is not None and type(self.kill_me) is not UcError: raise self.kill_me - if self.exception.type != ExceptionType.NoException: + if self._exception.type != ExceptionType.NoException: # Handle the exception outside of the except handler continue else: @@ -1232,7 +1266,7 @@ def load_dll(self, file_name: str, file_data: bytes): def _hook_code_exception(uc: Uc, address, size, dp: Dumpulator): try: dp.info(f"exception step: {hex(address)}[{size}]") - ex = dp.exception + ex = dp._exception ex.step_count += 1 if ex.step_count >= ex.tb_icount: raise Exception("Stepped past the basic block without reaching exception") @@ -1246,18 +1280,27 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator): return True fetch_accesses = [UC_MEM_FETCH, UC_MEM_FETCH_PROT, UC_MEM_FETCH_UNMAPPED] - if access == UC_MEM_FETCH_UNMAPPED and address >= FORCE_KILL_ADDR - 0x10 and address <= FORCE_KILL_ADDR + 0x10 and dp.kill_me is not None: + if access == UC_MEM_FETCH_UNMAPPED and FORCE_KILL_ADDR - 0x10 <= address <= FORCE_KILL_ADDR + 0x10 and dp.kill_me is not None: dp.error(f"forced exit memory operation {access} of {hex(address)}[{hex(size)}] = {hex(value)}") return False - if dp.exception.final and access in fetch_accesses: + if dp._exception.final and access in fetch_accesses: dp.info(f"fetch from {hex(address)}[{size}] already reported") return False # TODO: figure out why when you start executing at 0 this callback is triggered more than once try: + violation = { + UC_MEM_READ_UNMAPPED: MemoryViolation.ReadUnmapped, + UC_MEM_WRITE_UNMAPPED: MemoryViolation.WriteUnmapped, + UC_MEM_FETCH_UNMAPPED: MemoryViolation.ExecuteUnmapped, + UC_MEM_READ_PROT: MemoryViolation.ReadProtect, + UC_MEM_WRITE_PROT: MemoryViolation.WriteProtect, + UC_MEM_FETCH_PROT: MemoryViolation.ExecuteProtect, + }.get(access, MemoryViolation.Unknown) + assert violation != MemoryViolation.Unknown, f"Unexpected memory access {access}" # Extract exception information - exception = ExceptionInfo() + exception = UnicornExceptionInfo() exception.type = ExceptionType.Memory - exception.memory_access = access + exception.memory_violation = violation exception.memory_address = address exception.memory_size = size exception.memory_value = value @@ -1269,7 +1312,7 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator): exception.tb_icount = tb.icount # Print exception info - final = dp.trace or dp.exception.code_hook_h is not None + final = dp.trace or dp._exception.code_hook_h is not None info = "final" if final else "initial" if access == UC_MEM_READ_UNMAPPED: dp.error(f"{info} unmapped read from {hex(address)}[{hex(size)}], cip = {hex(dp.regs.cip)}, exception: {exception}") @@ -1295,25 +1338,25 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator): if final: # Make sure this is the same exception we expect if not dp.trace: - assert access == dp.exception.memory_access - assert address == dp.exception.memory_address - assert size == dp.exception.memory_size - assert value == dp.exception.memory_value + assert violation == dp._exception.memory_violation + assert address == dp._exception.memory_address + assert size == dp._exception.memory_size + assert value == dp._exception.memory_value # Delete the code hook - uc.hook_del(dp.exception.code_hook_h) - dp.exception.code_hook_h = None + uc.hook_del(dp._exception.code_hook_h) + dp._exception.code_hook_h = None # At this point we know for sure the context is correct so we can report the exception - dp.exception = exception - dp.exception.final = True + dp._exception = exception + dp._exception.final = True # Stop emulation (we resume it on KiUserExceptionDispatcher later) dp.stop() return False # There should not be an exception active - assert dp.exception.type == ExceptionType.NoException + assert dp._exception.type == ExceptionType.NoException # Remove the translation block cache for this block # Without doing this single stepping the block won't work @@ -1325,7 +1368,7 @@ def _hook_mem(uc: Uc, access, address, size, value, dp: Dumpulator): exception.code_hook_h = uc.hook_add(UC_HOOK_CODE, _hook_code_exception, user_data=dp) # Store the exception info - dp.exception = exception + dp._exception = exception # Stop emulation (we resume execution later) dp.stop() @@ -1452,7 +1495,7 @@ def _arg_type_string(arg): def _hook_interrupt(uc: Uc, number, dp: Dumpulator): try: # Extract exception information - exception = ExceptionInfo() + exception = UnicornExceptionInfo() exception.type = ExceptionType.Interrupt exception.interrupt_number = number exception.context = uc.context_save() @@ -1470,11 +1513,11 @@ def _hook_interrupt(uc: Uc, number, dp: Dumpulator): dp.error(f"interrupt {number} ({description}), cip = {hex(dp.regs.cip)}, cs = {hex(dp.regs.cs)}") # There should not be an exception active - assert dp.exception.type == ExceptionType.NoException + assert dp._exception.type == ExceptionType.NoException # At this point we know for sure the context is correct so we can report the exception - dp.exception = exception - dp.exception.final = True + dp._exception = exception + dp._exception.final = True except AssertionError as err: traceback.print_exc() raise err @@ -1560,7 +1603,7 @@ def syscall_arg(index): status = syscall_impl(dp, *args) if isinstance(status, ExceptionInfo): print("context switch, stopping emulation") - dp.exception = status + dp._exception = status raise dp.raise_kill(UcError(UC_ERR_EXCEPTION)) from None else: dp.info(f"status = {hex(status)}") @@ -1610,11 +1653,11 @@ def _hook_invalid(uc: Uc, dp: Dumpulator): instr = next(dp.cs.disasm(code, address, 1)) if _emulate_unsupported_instruction(dp, instr): # Resume execution with a context switch - assert dp.exception.type == ExceptionType.NoException - exception = ExceptionInfo() + assert dp._exception.type == ExceptionType.NoException + exception = UnicornExceptionInfo() exception.type = ExceptionType.ContextSwitch exception.final = True - dp.exception = exception + dp._exception = exception return False # NOTE: returning True would stop emulation except StopIteration: pass # Unsupported instruction diff --git a/src/dumpulator/native.py b/src/dumpulator/native.py index f4b4594..712e9fe 100644 --- a/src/dumpulator/native.py +++ b/src/dumpulator/native.py @@ -8,6 +8,7 @@ # NTSTATUS STATUS_SUCCESS = 0 STATUS_NOT_IMPLEMENTED = 0xC0000002 +STATUS_ACCESS_VIOLATION = 0xC0000005 STATUS_INVALID_HANDLE = 0xC0000008 STATUS_NO_SUCH_FILE = 0xC000000F STATUS_ACCESS_DENIED = 0xC0000022 diff --git a/src/dumpulator/ntsyscalls.py b/src/dumpulator/ntsyscalls.py index 7edc7a7..aeb1035 100644 --- a/src/dumpulator/ntsyscalls.py +++ b/src/dumpulator/ntsyscalls.py @@ -2,7 +2,7 @@ import struct import unicorn -from .dumpulator import Dumpulator, syscall_functions, ExceptionInfo, ExceptionType +from .dumpulator import Dumpulator from .native import * from .handles import * from .memory import * @@ -13,6 +13,7 @@ def syscall(func): if name[:2] not in ["Zw", "Nt"]: raise Exception(f"All syscalls have to be prefixed with 'Zw' or 'Nt'") # Add the function with both prefixes to avoid name bugs + from .dumpulator import syscall_functions syscall_functions["Zw" + name[2:]] = func syscall_functions["Nt" + name[2:]] = func return func @@ -779,7 +780,10 @@ def ZwContinue(dp: Dumpulator, ): # Trigger a context switch assert not TestAlert - exception = ExceptionInfo() + + # TODO: move this to a dedicated helper method + from .dumpulator import UnicornExceptionInfo, ExceptionType + exception = UnicornExceptionInfo() exception.type = ExceptionType.ContextSwitch exception.final = True context_type = CONTEXT if dp.ptr_size() == 8 else WOW64_CONTEXT @@ -4519,7 +4523,10 @@ def ZwTerminateProcess(dp: Dumpulator, ): assert ProcessHandle == 0 or ProcessHandle == dp.NtCurrentProcess() dp.stop(ExitStatus) - exception = ExceptionInfo() + + # TODO: move this to a dedicated helper method + from .dumpulator import UnicornExceptionInfo, ExceptionType + exception = UnicornExceptionInfo() exception.type = ExceptionType.Terminate exception.final = True exception.context = dp._uc.context_save()