Skip to content

Commit

Permalink
Refactor DLL callbacks (#665)
Browse files Browse the repository at this point in the history
Cherry-picked from primedev and slightly modified

Co-authored-by: F1F7Y <[email protected]>
  • Loading branch information
ASpoonPlaysGames and F1F7Y authored Aug 18, 2024
1 parent a28c1cb commit 5c730b0
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 83 deletions.
2 changes: 2 additions & 0 deletions primedev/Northstar.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ add_library(
"util/version.h"
"util/wininfo.cpp"
"util/wininfo.h"
"windows/libsys.cpp"
"windows/libsys.h"
"dllmain.cpp"
"ns_version.h"
"Northstar.def"
Expand Down
85 changes: 4 additions & 81 deletions primedev/core/hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#include <filesystem>
#include <Psapi.h>

#define XINPUT1_3_DLL "XInput1_3.dll"

namespace fs = std::filesystem;

AUTOHOOK_INIT()
Expand Down Expand Up @@ -392,87 +390,12 @@ void CallAllPendingDLLLoadCallbacks()
}
}

// clang-format off
AUTOHOOK_ABSOLUTEADDR(_LoadLibraryExA, (LPVOID)LoadLibraryExA,
HMODULE, WINAPI, (LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags))
// clang-format on
{
HMODULE moduleAddress;

LPCSTR lpLibFileNameEnd = lpLibFileName + strlen(lpLibFileName);
LPCSTR lpLibName = lpLibFileNameEnd - strlen(XINPUT1_3_DLL);

// replace xinput dll with one that has ASLR
if (lpLibFileName <= lpLibName && !strncmp(lpLibName, XINPUT1_3_DLL, strlen(XINPUT1_3_DLL) + 1))
{
moduleAddress = _LoadLibraryExA("XInput9_1_0.dll", hFile, dwFlags);

if (!moduleAddress)
{
MessageBoxA(0, "Could not find XInput9_1_0.dll", "Northstar", MB_ICONERROR);
exit(EXIT_FAILURE);

return nullptr;
}
}
else
moduleAddress = _LoadLibraryExA(lpLibFileName, hFile, dwFlags);

if (moduleAddress)
{
CallLoadLibraryACallbacks(lpLibFileName, moduleAddress);
g_pPluginManager->InformDllLoad(moduleAddress, fs::path(lpLibFileName));
}

return moduleAddress;
}

// clang-format off
AUTOHOOK_ABSOLUTEADDR(_LoadLibraryA, (LPVOID)LoadLibraryA,
HMODULE, WINAPI, (LPCSTR lpLibFileName))
// clang-format on
{
HMODULE moduleAddress = _LoadLibraryA(lpLibFileName);

if (moduleAddress)
CallLoadLibraryACallbacks(lpLibFileName, moduleAddress);

return moduleAddress;
}

// clang-format off
AUTOHOOK_ABSOLUTEADDR(_LoadLibraryExW, (LPVOID)LoadLibraryExW,
HMODULE, WINAPI, (LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags))
// clang-format on
{
HMODULE moduleAddress = _LoadLibraryExW(lpLibFileName, hFile, dwFlags);

if (moduleAddress)
CallLoadLibraryWCallbacks(lpLibFileName, moduleAddress);

return moduleAddress;
}

// clang-format off
AUTOHOOK_ABSOLUTEADDR(_LoadLibraryW, (LPVOID)LoadLibraryW,
HMODULE, WINAPI, (LPCWSTR lpLibFileName))
// clang-format on
{
HMODULE moduleAddress = _LoadLibraryW(lpLibFileName);

if (moduleAddress)
{
CallLoadLibraryWCallbacks(lpLibFileName, moduleAddress);
g_pPluginManager->InformDllLoad(moduleAddress, fs::path(lpLibFileName));
}

return moduleAddress;
}

void InstallInitialHooks()
void HookSys_Init()
{
if (MH_Initialize() != MH_OK)
{
spdlog::error("MH_Initialize (minhook initialization) failed");

}
// todo: remove remaining instances of autohook in this file
AUTOHOOK_DISPATCH()
}
28 changes: 27 additions & 1 deletion primedev/core/hooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,33 @@
#include <string>
#include <iostream>

void InstallInitialHooks();
//-----------------------------------------------------------------------------
// Purpose: Init minhook
//-----------------------------------------------------------------------------
void HookSys_Init();

//-----------------------------------------------------------------------------
// Purpose: MH_MakeHook wrapper
// Input : *ppOriginal - Original function being detoured
// pDetour - Detour function
//-----------------------------------------------------------------------------
inline void HookAttach(PVOID* ppOriginal, PVOID pDetour)
{
PVOID pAddr = *ppOriginal;
if (MH_CreateHook(pAddr, pDetour, ppOriginal) == MH_OK)
{
if (MH_EnableHook(pAddr) != MH_OK)
{
spdlog::error("Failed enabling a function hook!");
}
}
else
{
spdlog::error("Failed creating a function hook!");
}
}

void CallLoadLibraryACallbacks(LPCSTR lpLibFileName, HMODULE moduleAddress);

typedef void (*DllLoadCallbackFuncType)(CModule moduleAddress);
void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback, std::string tag = "", std::vector<std::string> reliesOn = {});
Expand Down
8 changes: 7 additions & 1 deletion primedev/dllmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "squirrel/squirrel.h"
#include "server/serverpresence.h"

#include "windows/libsys.h"

#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
Expand Down Expand Up @@ -64,7 +66,11 @@ bool InitialiseNorthstar()
// Write launcher version to log
StartupLog();

InstallInitialHooks();
// Init minhook
HookSys_Init();

// Init loadlibrary callbacks
LibSys_Init();

g_pServerPresence = new ServerPresenceManager();

Expand Down
123 changes: 123 additions & 0 deletions primedev/windows/libsys.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "libsys.h"
#include "plugins/pluginmanager.h"

#define XINPUT1_3_DLL "XInput1_3.dll"

typedef HMODULE (*WINAPI ILoadLibraryA)(LPCSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
typedef HMODULE (*WINAPI ILoadLibraryW)(LPCWSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);

ILoadLibraryA o_LoadLibraryA = nullptr;
ILoadLibraryExA o_LoadLibraryExA = nullptr;
ILoadLibraryW o_LoadLibraryW = nullptr;
ILoadLibraryExW o_LoadLibraryExW = nullptr;

//-----------------------------------------------------------------------------
// Purpose: Run detour callbacks for given HMODULE
//-----------------------------------------------------------------------------
void LibSys_RunModuleCallbacks(HMODULE hModule)
{
if (!hModule)
{
return;
}

// Get module base name in ASCII as noone wants to deal with unicode
CHAR szModuleName[MAX_PATH];
GetModuleBaseNameA(GetCurrentProcess(), hModule, szModuleName, MAX_PATH);

// DevMsg(eLog::NONE, "%s\n", szModuleName);

// Call callbacks
CallLoadLibraryACallbacks(szModuleName, hModule);
g_pPluginManager->InformDllLoad(hModule, fs::path(szModuleName));
}

//-----------------------------------------------------------------------------
// Load library callbacks

HMODULE WINAPI WLoadLibraryA(LPCSTR lpLibFileName)
{
HMODULE hModule = o_LoadLibraryA(lpLibFileName);

LibSys_RunModuleCallbacks(hModule);

return hModule;
}

HMODULE WINAPI WLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
HMODULE hModule;

LPCSTR lpLibFileNameEnd = lpLibFileName + strlen(lpLibFileName);
LPCSTR lpLibName = lpLibFileNameEnd - strlen(XINPUT1_3_DLL);

// replace xinput dll with one that has ASLR
if (lpLibFileName <= lpLibName && !strncmp(lpLibName, XINPUT1_3_DLL, strlen(XINPUT1_3_DLL) + 1))
{
hModule = o_LoadLibraryExA("XInput9_1_0.dll", hFile, dwFlags);

if (!hModule)
{
MessageBoxA(0, "Could not find XInput9_1_0.dll", "Northstar", MB_ICONERROR);
exit(EXIT_FAILURE);

return nullptr;
}
}
else
{
hModule = o_LoadLibraryExA(lpLibFileName, hFile, dwFlags);
}

bool bShouldRunCallbacks =
!(dwFlags & (LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE));
if (bShouldRunCallbacks)
{
LibSys_RunModuleCallbacks(hModule);
}

return hModule;
}

HMODULE WINAPI WLoadLibraryW(LPCWSTR lpLibFileName)
{
HMODULE hModule = o_LoadLibraryW(lpLibFileName);

LibSys_RunModuleCallbacks(hModule);

return hModule;
}

HMODULE WINAPI WLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
HMODULE hModule = o_LoadLibraryExW(lpLibFileName, hFile, dwFlags);

bool bShouldRunCallbacks =
!(dwFlags & (LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE));
if (bShouldRunCallbacks)
{
LibSys_RunModuleCallbacks(hModule);
}

return hModule;
}

//-----------------------------------------------------------------------------
// Purpose: Initilase dll load callbacks
//-----------------------------------------------------------------------------
void LibSys_Init()
{
HMODULE hKernel = GetModuleHandleA("Kernel32.dll");

o_LoadLibraryA = reinterpret_cast<ILoadLibraryA>(GetProcAddress(hKernel, "LoadLibraryA"));
o_LoadLibraryExA = reinterpret_cast<ILoadLibraryExA>(GetProcAddress(hKernel, "LoadLibraryExA"));
o_LoadLibraryW = reinterpret_cast<ILoadLibraryW>(GetProcAddress(hKernel, "LoadLibraryW"));
o_LoadLibraryExW = reinterpret_cast<ILoadLibraryExW>(GetProcAddress(hKernel, "LoadLibraryExW"));

HookAttach(&(PVOID&)o_LoadLibraryA, (PVOID)WLoadLibraryA);
HookAttach(&(PVOID&)o_LoadLibraryExA, (PVOID)WLoadLibraryExA);
HookAttach(&(PVOID&)o_LoadLibraryW, (PVOID)WLoadLibraryW);
HookAttach(&(PVOID&)o_LoadLibraryExW, (PVOID)WLoadLibraryExW);
}
3 changes: 3 additions & 0 deletions primedev/windows/libsys.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#pragma once

void LibSys_Init();

0 comments on commit 5c730b0

Please sign in to comment.