diff --git a/src/commands_builtin.cpp b/src/commands_builtin.cpp index f2e6751..1359c4c 100644 --- a/src/commands_builtin.cpp +++ b/src/commands_builtin.cpp @@ -19,29 +19,8 @@ using ra2yrcpp::command::get_cmd; -/// Adds two unsigned integers, and returns result in EAX -struct TestProgram : Xbyak::CodeGenerator { - TestProgram() { - mov(eax, ptr[esp + 1 * 0x4]); - add(eax, ptr[esp + 2 * 0x4]); - entry_size = getSize(); - ret(); - } - - auto get_code() { return getCode(); } - - std::size_t entry_size; -}; - -static void test_cb(hook::Hook*, void* data, X86Regs*) { - auto I = static_cast(data); - std::string s("0xbeefdead"); - I->store_value("test_key", s.begin(), s.end()); -} - -// TODO(shmocz): ditch the old hook/cb test functions to use the common -// functions -std::map get_commands_nn() { +std::map +ra2yrcpp::commands_builtin::get_commands() { return { get_cmd([](auto* Q) { // NB: ensure correct radix @@ -75,55 +54,5 @@ std::map get_commands_nn() { c.set_value(ra2yrcpp::to_string( *reinterpret_cast(Q->I()->get_value(c.key(), false)))); }), - get_cmd([](auto* Q) { - static TestProgram t; - auto t_addr = t.get_code(); - t_addr(3, 3); - - auto& res = Q->command_data(); - res.set_address_test_function(reinterpret_cast(t_addr)); - res.set_address_test_callback(reinterpret_cast(&test_cb)); - res.set_code_size(t.entry_size); - }), - get_cmd([](auto* Q) { - auto& a = Q->command_data(); - hook::HookCallback CB{.func = reinterpret_cast( - a.callback_address()), - .user_data = Q->I()}; - Q->I() - ->hooks() - .at(static_cast(a.hook_address())) - .add_callback(CB); - }), - get_cmd([](auto* Q) { -// TODO(shmocz): put these to utility function and share code with -// Hook code. -#ifdef _WIN32 - auto P = process::get_current_process(); - std::vector ns(Q->I()->get_connection_threads()); - - auto& a = Q->command_data(); - - // suspend threads? - if (!a.no_suspend_threads()) { - ns.push_back(process::get_current_tid()); - P.suspend_threads(ns); - } - - // create hooks - for (auto& h : a.hooks()) { - Q->I()->create_hook(h.name(), - static_cast(h.address()), - h.code_length()); - } - if (!a.no_suspend_threads()) { - P.resume_threads(ns); - } -#endif - })}; -} - -std::map -ra2yrcpp::commands_builtin::get_commands() { - return get_commands_nn(); + }; } diff --git a/src/commands_yr.cpp b/src/commands_yr.cpp index 07dc4de..fba3243 100644 --- a/src/commands_yr.cpp +++ b/src/commands_yr.cpp @@ -5,7 +5,6 @@ #include "command/is_command.hpp" #include "config.hpp" -#include "errors.hpp" #include "hooks_yr.hpp" #include "logging.hpp" #include "protocol/helpers.hpp" @@ -25,8 +24,6 @@ using ra2yrcpp::command::get_async_cmd; using ra2yrcpp::command::get_cmd; using ra2yrcpp::command::message_result; -using ra2yrcpp::hooks_yr::ensure_storage_value; -using ra2yrcpp::hooks_yr::get_data; using ra2yrcpp::hooks_yr::get_gameloop_command; // TODO(shmocz): don't allow deploying of already deployed object @@ -96,66 +93,32 @@ auto unit_command() { }); } -auto create_callbacks() { - return get_cmd([](auto* Q) { - auto [lk_s, s] = Q->I()->aq_storage(); - // Create main game data structure - // TODO(shmocz): initialize elsewhere - ra2yrcpp::hooks_yr::init_callbacks(get_data(Q->I())); - auto cbs = ra2yrcpp::hooks_yr::get_callbacks(Q->I()); - auto [lk, hhooks] = Q->I()->aq_hooks(); - for (auto& [k, v] : *cbs) { - auto target = v->target(); - auto h = std::find_if(hhooks->begin(), hhooks->end(), [&](auto& a) { - return (a.second.name() == target); - }); - if (h == hhooks->end()) { - throw std::runtime_error(fmt::format("No such hook {}", target)); - } - - const std::string hook_name = k; - auto& tmp_cbs = h->second.callbacks(); - // TODO(shmocz): throw standard exception - if (std::find_if(tmp_cbs.begin(), tmp_cbs.end(), [&hook_name](auto& a) { - return a.name == hook_name; - }) != tmp_cbs.end()) { - throw std::runtime_error(fmt::format( - "Hook {} already has a callback {}", target, hook_name)); - } - - iprintf("add callback, target={} cb={}", target, hook_name); - auto cb = v.get(); - cb->add_to_hook(&h->second, Q->I()); - } - }); -} - auto get_game_state() { return get_cmd([](auto* Q) { - Q->I()->lock_storage(); - auto* D = get_data(Q->I()); + auto& M = ra2yrcpp::hooks_yr::MainData::get(); + M.lock(); + auto* D = M.data(); // Unpause game if single-step mode. if (D->cfg.single_step() && D->game_paused.get()) { - Q->I()->unlock_storage(); + M.unlock(); D->game_paused.wait(true); - Q->I()->lock_storage(); + M.lock(); } Q->command_data().mutable_state()->CopyFrom(D->sv.game_state()); if (D->cfg.single_step()) { D->game_paused.store(false); } - Q->I()->unlock_storage(); + M.unlock(); }); } auto inspect_configuration() { return get_cmd([](auto* Q) { - auto [mut, s] = Q->I()->aq_storage(); + auto [mut, M] = ra2yrcpp::hooks_yr::MainData::acquire(); auto& res = Q->command_data(); - auto* cfg = &ra2yrcpp::hooks_yr::get_data(Q->I())->cfg; - cfg->MergeFrom(Q->command_data().config()); - res.mutable_config()->CopyFrom(*cfg); + M->update_configuration(res.config()); + res.mutable_config()->CopyFrom(M->data()->cfg); }); } @@ -294,7 +257,7 @@ static void convert_map_data(ra2yrproto::ra2yr::MapDataSoA* dst, /// Read a protobuf message from storage determined by command argument type. auto read_value() { return get_cmd([](auto* Q) { - auto [mut, s] = Q->I()->aq_storage(); + auto [mut, M] = ra2yrcpp::hooks_yr::MainData::acquire(); auto& A = Q->command_data(); // find the first field that's been set auto sf = ra2yrcpp::protocol::find_set_fields(A.data()); @@ -306,10 +269,10 @@ auto read_value() { if (fld->name() == "map_data_soa") { convert_map_data(D->mutable_map_data_soa(), - get_data(Q->I())->sv.mutable_map_data()); + M->data()->sv.mutable_map_data()); } else { // TODO(shmocz): use oneof - ra2yrcpp::protocol::copy_field(D, &get_data(Q->I())->sv, fld); + ra2yrcpp::protocol::copy_field(D, &M->data()->sv, fld); } }); } @@ -321,7 +284,6 @@ commands_yr::get_commands() { return { cmd::click_event(), // cmd::unit_command(), // - cmd::create_callbacks(), // cmd::get_game_state(), // cmd::inspect_configuration(), // cmd::mission_clicked(), // diff --git a/src/hook.cpp b/src/hook.cpp index 1cf7962..7e88223 100644 --- a/src/hook.cpp +++ b/src/hook.cpp @@ -1,43 +1,38 @@ #include "hook.hpp" #include "logging.hpp" -#include "utility/time.hpp" +#include "process.hpp" #include "x86.hpp" #include -#include -#include - using namespace hook; -using namespace std::chrono_literals; -DetourMain::DetourMain(addr_t target, addr_t hook, std::size_t code_length, - addr_t call_hook, unsigned int* count_enter, - unsigned int* count_exit) { +static void patch_code(u8* target_address, const u8* code, + std::size_t code_length) { + dprintf("address={}, bytes={}", reinterpret_cast(target_address), + code_length); + auto P = process::get_current_process(); + P.write_memory(target_address, code, code_length); +} + +// This differs from Syringe, which uses relative call +SyringeHook::SyringeHook(addr_t target, addr_t hook_function, + std::size_t code_length) { nop(code_length, false); // placeholder for original instruction(s) x86::save_regs(this); - push(hook); - mov(eax, call_hook); - lock(); - inc(dword[count_enter]); + push(esp); + mov(eax, hook_function); call(eax); - lock(); - inc(dword[count_exit]); add(esp, 0x4); x86::restore_regs(this); push(target + code_length); ret(); } -DetourMain::DetourMain(Hook* h) - : DetourMain(h->detour().src_address, reinterpret_cast(h), - h->detour().code_length, reinterpret_cast(&h->call), - h->count_enter(), h->count_exit()) {} - // TODO: fail if code is too short -struct DetourTrampoline : Xbyak::CodeGenerator { - DetourTrampoline(const u8* target, std::size_t code_length) { +struct JumpTo : Xbyak::CodeGenerator { + JumpTo(const u8* target, std::size_t code_length) { push(reinterpret_cast(target)); ret(); const std::size_t pad_length = code_length - getSize(); @@ -48,163 +43,16 @@ struct DetourTrampoline : Xbyak::CodeGenerator { } }; -unsigned int num_threads_at_tgt(const process::Process& P, const u8* target, - std::size_t length) { - auto main_tid = process::get_current_tid(); - std::vector ips; - P.for_each_thread([&ips, &main_tid](process::Thread* T, void* ctx) { - (void)ctx; - if (T->id() != main_tid) { - ips.push_back(*T->get_pgpr(x86Reg::eip)); - } - }); - unsigned int res = 0; - const auto t = reinterpret_cast(target); - for (auto eip : ips) { - dprintf("eip,beg,end={},{},{}", eip, t, - t + static_cast(length)); - if (eip >= t && (eip < t + length)) { - ++res; - } - } - return res; -} - -Hook::Hook(addr_t src_address, std::size_t code_length, std::string name, - std::vector no_suspend, bool manual) - : d_{src_address, 0u, code_length}, - name_(name), - dm_(this), - no_suspend_(no_suspend), - count_enter_(0u), - count_exit_(0u), - manual_(manual) { - // Create detour - auto p = dm_.getCode(); - - // Copy original instruction to detour - patch_code(p, reinterpret_cast(src_address), code_length); - - // Patch target region - DetourTrampoline D(p, code_length); - auto f = D.getCode(); - if (manual_) { - patch_code(reinterpret_cast(src_address), f, D.getSize()); - } else { - patch_code_safe(reinterpret_cast(src_address), f, D.getSize()); - } -} - -void threads_resume_wait_pause(const process::Process& P, - duration_t m = 0.01s) { - auto main_tid = process::get_current_tid(); - P.resume_threads(main_tid); - util::sleep_ms(m); - P.suspend_threads(main_tid); -} - -Hook::~Hook() { - // Remove all callbacks - lock(); - callbacks().clear(); - unlock(); - - // Patch back original code - auto p = dm_.getCode(); - patch_code_safe(reinterpret_cast(detour().src_address), p, - detour().code_length); - auto P = process::get_current_process(); - // Wait until all threads have exited the hook - const auto main_tid = process::get_current_tid(); - P.suspend_threads(main_tid); - while (num_threads_at_tgt(P, p, detour().code_length) > 0 || - (*count_enter() != *count_exit())) { - threads_resume_wait_pause(P); - } - P.resume_threads(main_tid); -} - -void Hook::add_callback(HookCallback c) { - lock(); - callbacks_.push_back(c); - unlock(); -} - -void Hook::add_callback(std::function func, - void* user_data, std::string name) { - add_callback(HookCallback{func, user_data, 0U, name}); -} - -void Hook::call(Hook* H, X86Regs state) { - H->lock(); - unsigned off = 0u; - auto& C = H->callbacks(); - for (auto i = 0u; i < C.size(); i++) { - auto ix = i - off; - auto& c = C.at(ix); - c.func(H, c.user_data, &state); - c.calls += 1; - } - H->unlock(); -} - -std::vector& Hook::callbacks() { return callbacks_; } - -void Hook::lock() { mu_.lock(); } - -void Hook::unlock() { mu_.unlock(); } - -Detour& Hook::detour() { return d_; } - -const std::string& Hook::name() const { return name_; } - -void Hook::patch_code(u8* target_address, const u8* code, - std::size_t code_length) { - dprintf("address={}, bytes={}", reinterpret_cast(target_address), - code_length); - auto P = process::get_current_process(); - P.write_memory(target_address, code, code_length); -} - -void Hook::patch_code_safe(u8* target_address, const u8* code, - std::size_t code_length) { - auto P = process::get_current_process(); - auto main_tid = process::get_current_tid(); - - dprintf("suspending, tgt={}, code={}, len={}, main tid={}", - reinterpret_cast(target_address), - reinterpret_cast(code), code_length, main_tid); - auto ns = std::vector(no_suspend_); - ns.push_back(main_tid); - P.suspend_threads(ns); - dprintf("suspend done"); - // FIXME: broken! completely ignores no_suspend_ - // Wait until no thread is at target region - while (num_threads_at_tgt(P, target_address, code_length) > 0) { - dprintf("waiting until thread exits target region.."); - threads_resume_wait_pause(P); - } - patch_code(target_address, code, code_length); - P.resume_threads(ns); -} - -unsigned int* Hook::count_enter() { return &count_enter_; } - -unsigned int* Hook::count_exit() { return &count_exit_; } - -void Hook::remove_callback(std::string name) { - auto it = std::find_if(callbacks_.begin(), callbacks_.end(), - [&name](const auto& j) { return j.name == name; }); - if (it == callbacks_.end()) { - throw std::runtime_error(fmt::format("Callback not found: {}", name)); - } - callbacks_.erase(it); -} +Hook::Hook(HookEntry h, hook_fn fn) + : trampoline_(h.address, reinterpret_cast(fn), h.size) { + // Create the main hook prologue + // TODO: use correct signature + auto* p = trampoline_.getCode(); -Callback::Callback() : cpu_state(nullptr) {} + // Copy original instructions to prologue + patch_code(p, reinterpret_cast(h.address), h.size); -void Callback::call(hook::Hook*, void*, X86Regs* state) { - cpu_state = state; - do_call(); - cpu_state = nullptr; + // Patch target code with jump to hook prologue + JumpTo D(p, h.size); + patch_code(reinterpret_cast(h.address), D.getCode(), D.getSize()); } diff --git a/src/hook.hpp b/src/hook.hpp index d3a1871..4b8ce36 100644 --- a/src/hook.hpp +++ b/src/hook.hpp @@ -1,6 +1,5 @@ #pragma once -#include "process.hpp" #include "types.h" #include @@ -8,125 +7,53 @@ #include #undef ERROR #undef OK -#include -#include -#include -#include namespace hook { -constexpr unsigned DETOUR_MAX_SIZE = 128U; -constexpr u8 OP_PUSH = 0x68; -constexpr u8 OP_RET = 0xc3; +#pragma pack(push, 16) -using process::thread_id_t; - -/// Represents control flow redirection at particular location. -struct Detour { - addr_t src_address; - addr_t detour_address; - std::size_t code_length; +struct HookEntry { + u32 address; + u32 size; + const char* name; }; -class Hook; - -struct DetourMain : Xbyak::CodeGenerator { - DetourMain(addr_t target, addr_t hook, std::size_t code_length, - addr_t call_hook, unsigned int* count_enter, - unsigned int* count_exit); - - explicit DetourMain(Hook* h); -}; +#pragma pack(pop) -struct HookCallback { - std::function func; - void* user_data; +using hook_fn = u32 __cdecl (*)(X86Regs*); - // How many times the callback has been invoked. - unsigned calls{0u}; - std::string name{""}; +/// In Syringe hook format, the overwritten instruction is executed after hook. +/// NB: We don't need to obey syringe prologue rigorously. As long export the +/// hook functions in Syringe compatible format, and ensure that REGISTER +/// argument is passed correctly, we should be fine. Only when ra2yrcpp is +/// executed without Syringe (e.g.) with custom loader code, the internal +/// prologue format should be used. +struct SyringeHook : Xbyak::CodeGenerator { + SyringeHook(addr_t target, addr_t hook_function, std::size_t code_length); }; -/// -/// Install a hook into memory location. This is implementented as a -/// detour that executes user supplied callbacks. When the object is destroyed, -/// the hook is uninstalled with all it's allocated code regions freed. -/// -/// On WIN32, detour is installed by suspending all threads, allocating -/// executable memory region, writing the detour "trampoline" to it and patching -/// the code region at target address. If thread's instruction pointer is in -/// target region, suspend/resume is called repeatedly until it has exited -/// the region. Uninstallation is done by first removing all callbacks, -/// suspending all threads, restoring original code and freeing the detour -/// memory block. Detour trampoline is removed in same way as it was originally -/// installed. Once all threads have exited from the detour region, code region -/// is freed and Hook object destroyed. -/// class Hook { public: - typedef void (*hook_cb_t)(Hook* h, void* user_data, X86Regs* state); - - /// - /// @param src_address Address that will be hooked - /// @param code_length Number of bytes to copy to detour's location - /// @param name (Optional) Name of the hook - /// @param no_suspend (Optional) List of threads to not suspend during - /// patching (in addition to current thread id) - /// TODO: move constructor - /// - Hook(addr_t src_address, std::size_t code_length, std::string name = "", - std::vector no_suspend = {}, bool manual = false); - ~Hook(); - void add_callback(HookCallback c); - void add_callback(std::function func, - void* user_data, std::string name); - - /// Invoke all registered hook functions. This function is thread safe. - static void __cdecl call(Hook* H, X86Regs state); - std::vector& callbacks(); - - void lock(); - void unlock(); - Detour& detour(); - const std::string& name() const; - void patch_code(u8* target_address, const u8* code, std::size_t code_length); - - /// Wait until no thread is in target region, then patch code. - void patch_code_safe(u8* target_address, const u8* code, - std::size_t code_length); - /// Pointer to counter for enters to Hook::call. - unsigned int* count_enter(); - /// Pointer to counter for exits from Hook::call. - unsigned int* count_exit(); - /// Check if hook has a callback identified by name - void remove_callback(std::string name); + Hook(HookEntry h, hook_fn fn); private: - Detour d_; - std::string name_; - std::vector callbacks_; - std::mutex mu_; - DetourMain dm_; - std::vector no_suspend_; - unsigned int count_enter_; - unsigned int count_exit_; - bool manual_; -}; - -struct Callback { - Callback(); - Callback(const Callback&) = delete; - Callback& operator=(const Callback&) = delete; - virtual ~Callback() = default; - /// The function that will be stored in the HookCallback object - void call(hook::Hook* h, void* data, X86Regs* state); - /// Subclasses implement their callback logic by overriding this. - virtual void do_call() = 0; - /// Callback's name. Duplicate callbacks will not be added into Hook. - virtual std::string name() = 0; - /// Target hook name. - virtual std::string target() = 0; - X86Regs* cpu_state; + SyringeHook trampoline_; }; } // namespace hook + +#ifdef _MSC_VER +#define HOOK_SECTION_ENTRY(hook, funcname, size) \ + __declspec(allocate(".syhks00")) \ + HookEntry _hk_##hook##funcname = {hook, size, #funcname} +#else +#define HOOK_SECTION_ENTRY(hook, funcname, size) \ + HookEntry __attribute__((section(".syhks00"))) \ + _hk_##hook##funcname = {hook, size, #funcname} +#endif + +// NB. Syringe headers specify the HOOK_SECTION_ENTRY inside SyringeData::Hooks +// namespace. +#define DEFINE_HOOK(hook, funcname, size) \ + HOOK_SECTION_ENTRY(hook, funcname, size); \ + extern "C" __declspec(dllexport) DWORD __cdecl funcname(void* R) diff --git a/src/hooks_yr.cpp b/src/hooks_yr.cpp index 22b1b6d..1684bbf 100644 --- a/src/hooks_yr.cpp +++ b/src/hooks_yr.cpp @@ -7,13 +7,14 @@ #include "config.hpp" #include "hook.hpp" #include "instrumentation_service.hpp" +#include "is_context.hpp" #include "logging.hpp" #include "protocol/helpers.hpp" #include "ra2/abi.hpp" #include "ra2/state_context.hpp" #include "ra2/state_parser.hpp" #include "ra2/yrpp_export.hpp" -#include "utility/serialize.hpp" +#include "types.h" #include #include @@ -23,16 +24,22 @@ #include #include -#include +#include #include -#include #include -#include #include +#include #include +#include #include #include +#ifdef _MSC_VER +#pragma section(".syhks00", read, write) +#endif + +using hook::HookEntry; + using namespace ra2yrcpp::hooks_yr; using namespace std::chrono_literals; @@ -44,139 +51,86 @@ static auto default_configuration() { return C; } -GameDataYR::GameDataYR() : cfg(default_configuration()) { - ctx = std::make_unique(&abi, &sv); +static auto load_configuration() { + auto C = default_configuration(); + auto ts = std::to_string(static_cast( + std::chrono::high_resolution_clock::now().time_since_epoch().count())); + char* p = nullptr; + std::string record_path, traffic_path; + if ((p = std::getenv("RA2YRCPP_RECORD_PATH")) != nullptr) { + record_path = p; + if (record_path.empty()) { + record_path = fmt::format("record.{}.pb.gz", ts); + } + } + if ((p = std::getenv("RA2YRCPP_RECORD_TRAFFIC")) != nullptr) { + traffic_path = p; + if (traffic_path.empty()) { + traffic_path = fmt::format("traffic.{}.pb.gz", ts); + } + } + C.set_record_filename(record_path); + C.set_traffic_filename(traffic_path); + return C; } -cb_map_t* ra2yrcpp::hooks_yr::get_callbacks( - ra2yrcpp::InstrumentationService* I) { - return &get_data(I)->callbacks; +GameDataYR::GameDataYR() : cfg(load_configuration()) { + ctx = std::make_unique(&abi, &sv); } -CBYR::CBYR() {} - -ra2::abi::ABIGameMD* CBYR::abi() { return &data()->abi; } - -void CBYR::do_call() { - I->lock_storage(); - auto [mut_cc, cc] = abi()->acquire_code_generators(); - try { - exec(); - } catch (const std::exception& e) { - eprintf("{}: {}", name(), e.what()); - } - I->unlock_storage(); -} +ra2::abi::ABIGameMD* GameDataInterface::abi() { return &data()->abi; } -ra2yrproto::ra2yr::GameState* CBYR::game_state() { +ra2yrproto::ra2yr::GameState* GameDataInterface::game_state() { return data()->sv.mutable_game_state(); } -CBYR::tc_t* CBYR::type_classes() { +GameDataInterface::tc_t* GameDataInterface::type_classes() { return data()->sv.mutable_initial_game_state()->mutable_object_types(); } -ra2::StateContext* CBYR::get_state_context() { return data()->ctx.get(); } +ra2::StateContext* GameDataInterface::get_state_context() { + return data()->ctx.get(); +} + +ra2yrcpp::hooks_yr::GameDataYR* GameDataInterface::data() { + if (data_ == nullptr) { + data_ = MainData::get().data(); + } + return data_; +} -ra2yrcpp::hooks_yr::GameDataYR* CBYR::data() { - return data_ != nullptr ? data_ : get_data(I); +void MainData::update_configuration( + const ra2yrproto::commands::Configuration& C) { + auto* cfg = &data()->cfg; + if (C.parse_map_data_interval() > 0U) { + cfg->set_parse_map_data_interval(C.parse_map_data_interval()); + } + cfg->set_single_step(C.single_step()); } -auto* CBYR::prerequisite_groups() { +ra2yrproto::ra2yr::PrerequisiteGroups* +GameDataInterface::prerequisite_groups() { return data()->sv.mutable_initial_game_state()->mutable_prerequisite_groups(); } -ra2yrproto::commands::Configuration* CBYR::configuration() { +ra2yrproto::commands::Configuration* GameDataInterface::configuration() { return &data()->cfg; } -// TODO(shmocz): do the callback initialization later -struct CBExitGameLoop final - : public MyCB { - static constexpr char key_target[] = "on_gameloop_exit"; - static constexpr char key_name[] = "gameloop_exit"; - - CBExitGameLoop() = default; - CBExitGameLoop(const CBExitGameLoop& o) = delete; - CBExitGameLoop& operator=(const CBExitGameLoop& o) = delete; - CBExitGameLoop(CBExitGameLoop&& o) = delete; - CBExitGameLoop& operator=(CBExitGameLoop&& o) = delete; - ~CBExitGameLoop() override = default; - - void do_call() override { - // Delete all callbacks except ourselves - // NB. the corresponding HookCallback must be removed from Hook object - // (shared_ptr would be handy here) - auto [mut, s] = I->aq_storage(); - get_data(I)->sv.mutable_game_state()->set_stage( - ra2yrproto::ra2yr::STAGE_EXIT_GAME); - - auto [lk, hhooks] = I->aq_hooks(); - auto* callbacks = get_callbacks(I); - // Loop through all callbacks - std::vector keys; - std::transform(callbacks->begin(), callbacks->end(), - std::back_inserter(keys), - [](const auto& v) { return v.first; }); - - for (const auto& k : keys) { - if (k == name()) { - continue; - } - // Get corresponding hook - auto h = std::find_if(hhooks->begin(), hhooks->end(), [&](auto& a) { - return (a.second.name() == callbacks->at(k)->target()); - }); - // Remove callback's reference from Hook - if (h == hhooks->end()) { - eprintf("no hook found for callback {}", k); - } else { - h->second.remove_callback(k); - // Delete callback object - callbacks->erase(k); - } - } - - // Flush output in case the process is not terminated gracefully. - std::cerr << std::flush; - std::cout << std::flush; - } -}; - -struct CBUpdateLoadProgress final : public MyCB { - static constexpr char key_name[] = "cb_progress_update"; - static constexpr char key_target[] = "on_progress_update"; - - CBUpdateLoadProgress() = default; - - void exec() override { - auto* B = ProgressScreenClass::Instance().PlayerProgresses; - - auto* sv = &data()->sv; - auto* local_state = sv->mutable_load_state(); - if (local_state->load_progresses().empty()) { - for (auto i = 0U; i < (sizeof(*B) / sizeof(B)); i++) { - local_state->add_load_progresses(0.0); - } - } - for (int i = 0; i < local_state->load_progresses().size(); i++) { - local_state->set_load_progresses(i, B[i]); - } - sv->mutable_game_state()->set_stage( - ra2yrproto::ra2yr::LoadStage::STAGE_LOADING); - } -}; +template +static T* get_service_data() { + return reinterpret_cast(MainData::get().service_datas().at(T::id).get()); +} -struct CBSaveState final : public MyCB { +struct StateSave : public GameDataInterface { ra2yrcpp::protocol::MessageOstream out; utility::worker_util> work; ra2yrproto::ra2yr::GameState* initial_state; std::vector cells; - static constexpr char key_name[] = "save_state"; - static constexpr char key_target[] = "on_frame_update"; + static constexpr auto id = ServiceDataId::STATE_SAVE; - explicit CBSaveState(std::shared_ptr record_stream) + explicit StateSave(std::shared_ptr record_stream) : out(record_stream, true), work([this](const auto& w) { this->serialize_state(*w.get()); }, 10U), initial_state(nullptr) {} @@ -258,18 +212,28 @@ struct CBSaveState final : public MyCB { return std::make_shared(*gbuf); } - void exec() override { - // enables event debug logs - // *reinterpret_cast(0xa8ed74) = 1; + static std::unique_ptr create(std::string record_path) { + std::shared_ptr record_out = nullptr; + if (!record_path.empty()) { + iprintf("record state to {}", record_path); + record_out = std::make_shared( + record_path, std::ios_base::out | std::ios_base::binary); + } + return std::make_unique(record_out); + } + + void execute() { auto st = state_to_protobuf(type_classes()->empty()); work.push(st); } + + static StateSave* get() { return get_service_data(); } }; -template -struct CBTunnel : public MyCB { - public: +struct SaveTrafficData : ServiceData { + static constexpr ServiceDataId id = ServiceDataId::RECORD_TRAFFIC; using writer_t = std::shared_ptr; + writer_t out; struct packet_buffer { void* data; @@ -280,9 +244,17 @@ struct CBTunnel : public MyCB { u32 destination; }; - writer_t out; + explicit SaveTrafficData(writer_t out) : out(out) {} - explicit CBTunnel(writer_t out) : out(out) {} + packet_buffer recv_buffer(const X86Regs* cpu_state) { + return {reinterpret_cast(cpu_state->ebp + 0x3f074), + static_cast(cpu_state->esi), 1U, 0U}; + } + + packet_buffer send_buffer(const X86Regs* cpu_state) { + return {reinterpret_cast(cpu_state->ecx), + static_cast(cpu_state->eax), 0U, 1U}; + } void write_packet(u32 source, u32 dest, const void* buf, std::size_t len) { // dprintf("source={} dest={}, buf={}, len={}", source, dest, buf, len); @@ -291,117 +263,168 @@ struct CBTunnel : public MyCB { P.set_destination(dest); P.mutable_data()->assign(static_cast(buf), len); if (!out->write(P)) { - throw std::runtime_error( - fmt::format("{} write_packet failed", D::key_name)); + throw std::runtime_error("write_packet failed"); } } - virtual packet_buffer buffer() = 0; - - void exec() override { - auto b = buffer(); - if (b.size > 0) { + void write_packet(packet_buffer b) { + if (out != nullptr && b.size > 0) { write_packet(b.source, b.destination, b.data, b.size); } } -}; -// TODO(shmocz): pass smart ptr by reference? -struct CBTunnelRecvFrom final : public CBTunnel { - static constexpr char key_target[] = "cb_tunnel_recvfrom"; - static constexpr char key_name[] = "tunnel_recvfrom"; - - explicit CBTunnelRecvFrom(writer_t out) : CBTunnel(std::move(out)) {} + static std::unique_ptr create(std::string traffic_out) { + writer_t out = nullptr; + if (!traffic_out.empty()) { + out = std::make_shared( + std::make_shared( + traffic_out, std::ios_base::out | std::ios_base::binary), + true); + iprintf("record traffic to {}", traffic_out); + } - packet_buffer buffer() override { - return {reinterpret_cast(cpu_state->ebp + 0x3f074), - static_cast(cpu_state->esi), 1U, 0U}; + return std::make_unique(out); } + + static SaveTrafficData* get() { return get_service_data(); } }; -struct CBTunnelSendTo final : public CBTunnel { - static constexpr char key_target[] = "cb_tunnel_sendto"; - static constexpr char key_name[] = "tunnel_sendto"; +void MainData::initialize_service_datas() { + auto& C = data_->cfg; + service_datas_.try_emplace(ServiceDataId::STATE_SAVE, + StateSave::create(C.record_filename())); + service_datas_.try_emplace(ServiceDataId::GAME_COMMAND, + GameCommandData::create()); + service_datas_.try_emplace(ServiceDataId::RECORD_TRAFFIC, + SaveTrafficData::create(C.traffic_filename())); +} + +void MainData::deinitialize_service_datas() { service_datas_.clear(); } - explicit CBTunnelSendTo(writer_t out) : CBTunnel(std::move(out)) {} +GameCommandData::GameCommandData() = default; - packet_buffer buffer() override { - return {reinterpret_cast(cpu_state->ecx), - static_cast(cpu_state->eax), 0U, 1U}; +void GameCommandData::put_work(work_t fn) { work.push(fn); } + +void GameCommandData::consume_work() { + auto items = work.pop(0, 0.0s); + for (const auto& it : items) { + it(); } -}; +} -struct CBDebugPrint final : public MyCB { - static constexpr char key_target[] = "cb_debug_print"; - static constexpr char key_name[] = "debug_print"; +GameCommandData* GameCommandData::get() { + return get_service_data(); +} - CBDebugPrint() = default; +std::unique_ptr GameCommandData::create() { + return std::make_unique(); +} - // TODO(shmocz): store debug messages in record file - void exec() override { - if (configuration()->debug_log()) { - char buf[1024]; - std::memset(buf, 'F', sizeof(buf)); - abi()->sprintf(reinterpret_cast(&buf), cpu_state->esp + 0x4); - fmt::print(stderr, "({}) {}", serialize::read_obj(cpu_state->esp), - buf); - } +MainData* MainData::instance_ = nullptr; +std::recursive_mutex MainData::lock_; + +MainData::MainData() : data_(std::make_unique()) {} + +MainData& MainData::get() { + if (instance_ == nullptr) { + instance_ = new MainData(); } -}; + return *instance_; +} + +GameDataYR* MainData::data() { return data_.get(); } -ra2yrcpp::hooks_yr::GameDataYR* ra2yrcpp::hooks_yr::get_data( - ra2yrcpp::InstrumentationService* I) { - return ensure_storage_value(I, "game_data"); +void MainData::lock() { lock_.lock(); } + +void MainData::unlock() { lock_.unlock(); } + +std::map>& +MainData::service_datas() { + return service_datas_; } -// TODO(shmocz): ensure thread safety -void ra2yrcpp::hooks_yr::init_callbacks(ra2yrcpp::hooks_yr::GameDataYR* D) { - if (D->callbacks_initialized) { - return; - } - auto t = std::to_string(static_cast( - std::chrono::high_resolution_clock::now().time_since_epoch().count())); - auto f = [D](std::unique_ptr c) { - D->callbacks.try_emplace(c->name(), std::move(c)); - }; +util::acquire_t MainData::acquire() { + return util::acquire(&get(), &lock_); +} - if (std::getenv("RA2YRCPP_RECORD_TRAFFIC") != nullptr) { - const std::string traffic_out = fmt::format("traffic.{}.pb.gz", t); - iprintf("record traffic to {}", traffic_out); +DEFINE_HOOK(0x7b3d6f, TunnelSendTo, 0x6) { + auto [mut, M] = MainData::acquire(); - auto out = std::make_shared( - std::make_shared( - traffic_out, std::ios_base::out | std::ios_base::binary), - true); - f(std::make_unique(out)); - f(std::make_unique(out)); - } - f(std::make_unique()); - f(std::make_unique()); + auto* C = SaveTrafficData::get(); + C->write_packet(C->send_buffer(reinterpret_cast(R))); + return 0U; +} + +DEFINE_HOOK(0x7b3f15, TunnelRecvFrom, 0x6) { + auto [mut, M] = MainData::acquire(); + + auto* C = SaveTrafficData::get(); + C->write_packet(C->recv_buffer(reinterpret_cast(R))); + return 0U; +} + +DEFINE_HOOK(0x643c62, UpdateLoadProgress, 0x6) { + (void)R; + auto [mut, M] = MainData::acquire(); - std::shared_ptr record_out = nullptr; + auto* B = ProgressScreenClass::Instance().PlayerProgresses; - if (std::getenv("RA2YRCPP_RECORD_PATH") != nullptr) { - const std::string record_path = std::getenv("RA2YRCPP_RECORD_PATH"); - D->cfg.set_record_filename(record_path); - iprintf("record state to {}", record_path); - record_out = std::make_shared( - record_path, std::ios_base::out | std::ios_base::binary); + auto* data = M->data(); + auto* sv = &data->sv; + auto* local_state = sv->mutable_load_state(); + if (local_state->load_progresses().empty()) { + for (auto i = 0U; i < (sizeof(*B) / sizeof(B)); i++) { + local_state->add_load_progresses(0.0); + } + } + for (int i = 0; i < local_state->load_progresses().size(); i++) { + local_state->set_load_progresses(i, B[i]); } - f(std::make_unique(record_out)); - f(std::make_unique()); - f(std::make_unique()); + sv->mutable_game_state()->set_stage( + ra2yrproto::ra2yr::LoadStage::STAGE_LOADING); + return 0U; +} + +DEFINE_HOOK(0x72dfb0, ExitGameLoop, 0x6) { + (void)R; + auto [mut, M] = MainData::acquire(); + + GameCommandData::get()->game_state()->set_stage( + ra2yrproto::ra2yr::STAGE_EXIT_GAME); + M->deinitialize_service_datas(); + + // Flush output in case the process is not terminated gracefully. + std::cerr << std::flush; + std::cout << std::flush; + return 0U; } -constexpr std::array gg_hooks = {{ - {0x55de4f, 7U, CBGameCommand::key_target}, // - {0x72dfb0, 6U, CBExitGameLoop::key_target}, // - {0x7b3d6f, 6U, CBTunnelSendTo::key_target}, // - {0x7b3f15, 6U, CBTunnelRecvFrom::key_target}, // - {0x643c62, 6U, CBUpdateLoadProgress::key_target}, // - {0x4068e0, 6U, CBDebugPrint::key_target}, -}}; - -std::vector ra2yrcpp::hooks_yr::get_hooks() { - return std::vector(gg_hooks.begin(), gg_hooks.end()); +DEFINE_HOOK(0x55de4f, GameLoopBegin, 0x7) { + (void)R; + auto [mut, M] = MainData::acquire(); + + // Save state + StateSave::get()->execute(); + + // If in single-step mode, release storage lock and wait for game to be + // unlocked. + auto* D = M->data(); + if (D->cfg.single_step()) { + M->unlock(); + D->game_paused.store(true); + D->game_paused.wait(false); + M->lock(); + } + + GameCommandData::get()->consume_work(); + + return 0U; +}; + +DEFINE_HOOK(0x7cd84d, ExeRun, 0x9) { + (void)R; + auto [mut, M] = MainData::acquire(); + M->initialize_service_datas(); + is_context::RA2YRCPP::get()->start_service(); + return 0U; } diff --git a/src/hooks_yr.hpp b/src/hooks_yr.hpp index b1f8bb5..9126860 100644 --- a/src/hooks_yr.hpp +++ b/src/hooks_yr.hpp @@ -5,7 +5,6 @@ #include "async_queue.hpp" #include "command/is_command.hpp" -#include "instrumentation_service.hpp" #include "ra2/abi.hpp" #include "ra2/state_context.hpp" #include "types.h" @@ -13,12 +12,12 @@ #include -#include +#include + #include #include #include -#include -#include +#include namespace util_command { template @@ -27,13 +26,23 @@ struct ISCommand; namespace ra2yrcpp::hooks_yr { -namespace { -using namespace std::chrono_literals; -} - using gpb::RepeatedPtrField; -using cb_map_t = std::map>; +// General purpose data container to hold resources that need to be freed at +// game exit. +class ServiceData { + public: + ServiceData() = default; + virtual ~ServiceData() = default; +}; + +enum class ServiceDataId : u32 { + GAME_COMMAND = 0U, + RECORD_TRAFFIC = 1U, + STATE_SAVE = 2U +}; + +// Should this be a singleton? struct GameDataYR { GameDataYR(); @@ -41,105 +50,77 @@ struct GameDataYR { ra2yrproto::ra2yr::StorageValue sv; ra2yrproto::commands::Configuration cfg; std::unique_ptr ctx{nullptr}; - cb_map_t callbacks; - bool callbacks_initialized{false}; util::AtomicVariable game_paused{false}; }; -struct CBYR : public ra2yrcpp::ISCallback { - using tc_t = RepeatedPtrField; - - GameDataYR* data_{nullptr}; +/// Singleton +class MainData { + public: + void initialize_service_datas(); + void deinitialize_service_datas(); + GameDataYR* data(); + static MainData& get(); + static util::acquire_t acquire(); + static void lock(); + static void unlock(); + std::map>& service_datas(); + void update_configuration(const ra2yrproto::commands::Configuration& C); + + private: + static MainData* instance_; + static std::recursive_mutex lock_; + std::unique_ptr data_; + std::map> service_datas_; + MainData(); + ~MainData(); +}; - CBYR(); +class GameDataInterface : public ServiceData { + public: + using tc_t = RepeatedPtrField; ra2::abi::ABIGameMD* abi(); ra2yrproto::commands::Configuration* configuration(); - void do_call() override; - virtual void exec() = 0; ra2yrproto::ra2yr::GameState* game_state(); - auto* prerequisite_groups(); + ra2yrproto::ra2yr::PrerequisiteGroups* prerequisite_groups(); tc_t* type_classes(); ra2::StateContext* get_state_context(); GameDataYR* data(); -}; + /// Update the underlying configuration. Record/traffic paths are determined + /// at initialization and will be ignored. -ra2yrcpp::hooks_yr::GameDataYR* get_data(ra2yrcpp::InstrumentationService* I); - -/// Get all currently active callback objects. -cb_map_t* get_callbacks(ra2yrcpp::InstrumentationService* I); - -template -struct MyCB : public B { - std::string name() override { return D::key_name; } - - std::string target() override { return D::key_target; } - - static D* get(ra2yrcpp::InstrumentationService* I) { - return reinterpret_cast(get_data(I)->callbacks.at(D::key_name).get()); - } + private: + // NB. cyclic dependency + GameDataYR* data_{nullptr}; }; -// TODO(shmocz): reduce calls to this -template -T* ensure_storage_value(ra2yrcpp::InstrumentationService* I, std::string key, - ArgsT... args) { - if (I->storage().find(key) == I->storage().end()) { - I->store_value(key, args...); - } - return static_cast(I->storage().at(key).get()); -} - -void init_callbacks(ra2yrcpp::hooks_yr::GameDataYR* D); +class GameCommandData : public GameDataInterface { + public: + using work_t = std::function; + static constexpr auto id = ServiceDataId::GAME_COMMAND; -struct work_item { - CBYR* cb; - command::iservice_cmd* cmd; - std::function fn; -}; + GameCommandData(); + void put_work(work_t fn); + void consume_work(); -struct CBGameCommand final : public MyCB { - static constexpr char key_name[] = "cb_game_command"; - static constexpr char key_target[] = "on_frame_update"; - using work_t = std::function; + // Get global instance + static GameCommandData* get(); + static std::unique_ptr create(); + private: async_queue::AsyncQueue work; - - CBGameCommand() = default; - - void put_work(work_t fn) { work.push(fn); } - - void exec() override { - // If in single-step mode, release storage lock and wait for game to be - // unlocked. - if (data()->cfg.single_step()) { - I->unlock_storage(); - data()->game_paused.store(true); - data()->game_paused.wait(false); - I->lock_storage(); - } - - auto items = work.pop(0, 0.0s); - for (const auto& it : items) { - it(); - } - } }; template -void get_gameloop_command(ra2yrcpp::command::ISCommand* Q, - std::function fn) { - auto* cb = ra2yrcpp::hooks_yr::CBGameCommand::get(Q->I()); +void get_gameloop_command(const ra2yrcpp::command::ISCommand* Q, + std::function fn) { + auto* ctx = GameCommandData::get(); auto* cmd = Q->c; - cmd->set_async_handler([cb, fn](auto*) { fn(cb); }); - cb->put_work([cmd]() { cmd->run_async_handler(); }); + cmd->set_async_handler([ctx, fn](auto*) { fn(ctx); }); + ctx->put_work([cmd]() { cmd->run_async_handler(); }); } -struct YRHook { - u32 address; - u32 size; - const char* name; -}; - -std::vector get_hooks(); +void create_all_hooks(); +void create_all_hooks(char* hooks_section, std::size_t section_size, + void* dll_handle); }; // namespace ra2yrcpp::hooks_yr diff --git a/src/instrumentation_service.cpp b/src/instrumentation_service.cpp index 37996ba..22dcb0e 100644 --- a/src/instrumentation_service.cpp +++ b/src/instrumentation_service.cpp @@ -5,6 +5,7 @@ #include "asio_utils.hpp" #include "command/command_manager.hpp" #include "config.hpp" +#include "hook.hpp" #include "logging.hpp" #include "protocol/helpers.hpp" #include "util_string.hpp" @@ -20,44 +21,8 @@ using namespace ra2yrcpp; -ISCallback::ISCallback() : I(nullptr) {} - -ISCallback::~ISCallback() {} - -void ISCallback::add_to_hook(hook::Hook* h, - ra2yrcpp::InstrumentationService* I) { - this->I = I; - // TODO(shmocz): avoid using wrapper - h->add_callback([this](hook::Hook* h, void* user_data, - X86Regs* state) { this->call(h, user_data, state); }, - nullptr, name()); -} - -std::vector -InstrumentationService::get_connection_threads() { - std::vector res; - res.push_back(io_service_tid_.get()); - return res; -} - -void InstrumentationService::create_hook(const std::string& name, - const std::uintptr_t target, - std::size_t code_length) { - std::unique_lock lk(mut_hooks_); - iprintf("name={},target={:#x},size_bytes={}", name, target, code_length); - if (hooks_.find(target) != hooks_.end()) { - throw std::runtime_error( - fmt::format("Can't overwrite existing hook (name={} address={})", name, - reinterpret_cast(target))); - } - auto tids = get_connection_threads(); - hooks_.try_emplace(target, target, code_length, name, tids, true); -} - cmd_manager_t& InstrumentationService::cmd_manager() { return cmd_manager_; } -hooks_t& InstrumentationService::hooks() { return hooks_; } - static ra2yrproto::TextResponse text_response(std::string message) { ra2yrproto::TextResponse E; E.mutable_message()->assign(message); @@ -244,16 +209,6 @@ void* InstrumentationService::get_value(std::string key, bool acquire) { return storage_.at(key).get(); } -storage_t& InstrumentationService::storage() { return storage_; } - -util::acquire_t InstrumentationService::aq_hooks() { - return util::acquire(&hooks_, &mut_hooks_); -} - -void InstrumentationService::lock_storage() { mut_storage_.lock(); } - -void InstrumentationService::unlock_storage() { mut_storage_.unlock(); } - util::acquire_t InstrumentationService::aq_storage() { return util::acquire(&storage_, &mut_storage_); diff --git a/src/instrumentation_service.hpp b/src/instrumentation_service.hpp index 7775a90..319629b 100644 --- a/src/instrumentation_service.hpp +++ b/src/instrumentation_service.hpp @@ -10,7 +10,6 @@ #include "utility/sync.hpp" #include "websocket_server.hpp" -#include #include #include @@ -20,7 +19,6 @@ #include #include #include -#include namespace ra2yrcpp { namespace asio_utils { @@ -33,23 +31,13 @@ namespace ra2yrcpp { // Forward declaration class InstrumentationService; -/// Hook callback that provides access to InstrumentationService. -struct ISCallback : public hook::Callback { - ISCallback(); - ~ISCallback() override; - /// Add this callback to the given hook and assigns pointer to IService. - void add_to_hook(hook::Hook* h, ra2yrcpp::InstrumentationService* I); - - ra2yrcpp::InstrumentationService* I; -}; - +// TODO(shmocz): Deprecate storage because it's largely unused. using storage_t = std::map>>; using ra2yrcpp::websocket_server::WebsocketServer; using cmd_t = ra2yrcpp::command::iservice_cmd; using cmd_manager_t = ra2yrcpp::command::CommandManager; using command_ptr_t = cmd_manager_t::command_ptr_t; -using hooks_t = std::map; using command_hdl_t = command_ptr_t::weak_type; class InstrumentationService { @@ -72,24 +60,9 @@ class InstrumentationService { std::function extra_init = nullptr); ~InstrumentationService(); - /// - /// Returns OS specific thread id's for all active client connections. Mostly - /// useful during hooking to not suspend the connection threads. - /// - std::vector get_connection_threads(); - /// Create hook to given memory location - /// @param name - /// @param target target memory address - /// @param code_length the amount of bytes to copy into target detour location - void create_hook(const std::string& name, std::uintptr_t target, - std::size_t code_length); cmd_manager_t& cmd_manager(); - hooks_t& hooks(); - util::acquire_t aq_hooks(); // TODO(shmocz): separate storage class util::acquire_t aq_storage(); - void lock_storage(); - void unlock_storage(); template void store_value(std::string key, Args&&... args) { @@ -104,7 +77,6 @@ class InstrumentationService { /// @return pointer to the storage object /// @exception std::out_of_range if value doesn't exist void* get_value(std::string key, bool acquire = true); - storage_t& storage(); const InstrumentationService::Options& opts() const; static ra2yrcpp::InstrumentationService* create( InstrumentationService::Options O, @@ -123,8 +95,6 @@ class InstrumentationService { Options opts_; std::function on_shutdown_; cmd_manager_t cmd_manager_; - hooks_t hooks_; - std::mutex mut_hooks_; storage_t storage_; std::recursive_mutex mut_storage_; std::unique_ptr io_service_; diff --git a/src/is_context.cpp b/src/is_context.cpp index 97b7691..f51d95d 100644 --- a/src/is_context.cpp +++ b/src/is_context.cpp @@ -1,10 +1,5 @@ #include "is_context.hpp" -#include "protocol/protocol.hpp" -#include "ra2yrproto/commands_builtin.pb.h" -#include "ra2yrproto/commands_yr.pb.h" -#include "ra2yrproto/core.pb.h" - #include "command/is_command.hpp" #include "commands_builtin.hpp" #include "commands_game.hpp" @@ -12,12 +7,11 @@ #include "config.hpp" #include "context.hpp" #include "dll_inject.hpp" -#include "hooks_yr.hpp" +#include "hook.hpp" #include "instrumentation_service.hpp" #include "logging.hpp" #include "process.hpp" #include "types.h" -#include "utility/sync.hpp" #include "utility/time.hpp" #include "win32/windows_utils.hpp" #include "x86.hpp" @@ -35,8 +29,6 @@ using namespace std::chrono_literals; using namespace is_context; using x86::bytes_to_stack; -namespace gpb = google::protobuf; - ProcAddrs is_context::get_procaddrs() { ProcAddrs A; A.p_LoadLibrary = windows_utils::get_proc_address("LoadLibraryA"); @@ -50,12 +42,14 @@ vecu8 is_context::vecu8cstr(std::string s) { return r; } +// TODO: Get rid of the "Context" thingy. static Context* make_is_ctx(Context* c, const ra2yrcpp::InstrumentationService::Options O) { auto* I = is_context::make_is(O, [c](auto* X) { (void)X; return c->on_signal(); }); + c->data() = reinterpret_cast(I); c->deleter() = [](Context* ctx) { delete reinterpret_cast(ctx->data()); @@ -129,15 +123,6 @@ void is_context::get_procaddr(Xbyak::CodeGenerator* c, void* m, c->ret(); } -static void handle_cmd_wait(ra2yrcpp::InstrumentationService* I, - const gpb::Message& cmd) { - auto CC = ra2yrcpp::create_command(cmd); - util::AtomicVariable done(false); - (void)ra2yrcpp::handle_cmd(I, 0U, &CC, true, - [&done](auto*) { done.store(true); }); - done.wait(true); -} - ra2yrcpp::InstrumentationService* is_context::make_is( ra2yrcpp::InstrumentationService::Options O, std::function on_shutdown) { @@ -154,24 +139,6 @@ ra2yrcpp::InstrumentationService* is_context::make_is( for (auto& [name, fn] : cmds) { t->cmd_manager().add_command(name, fn); } - - if (!t->opts().no_init_hooks) { - ra2yrproto::commands::CreateHooks C1; - - C1.set_no_suspend_threads(true); - for (const auto& Y : ra2yrcpp::hooks_yr::get_hooks()) { - auto* H = C1.add_hooks(); - H->set_address(Y.address); - H->set_name(Y.name); - H->set_code_length(Y.size); - } - - handle_cmd_wait(t, C1); - ra2yrproto::commands::CreateCallbacks C2; - handle_cmd_wait(t, C2); - } else { - iprintf("not creating hooks and callbacks"); - } }); return I; @@ -220,3 +187,63 @@ void* is_context::get_context( const ra2yrcpp::InstrumentationService::Options O) { return make_is_ctx(new is_context::Context(), O); } + +void RA2YRCPP::create_hook(hook::HookEntry h, hook::hook_fn f) { + iprintf("name={},target={:#x},size_bytes={}", h.name, h.address, h.size); + if (hooks_.find(h.address) != hooks_.end()) { + throw std::runtime_error( + fmt::format("Can't overwrite existing hook (name={} address={})", + h.name, reinterpret_cast(h.address))); + } + hooks_.try_emplace(h.address, h, f); +} + +void RA2YRCPP::create_all_hooks(char* hooks_section, std::size_t section_size, + void* dll_handle) { + // For each hook entry + const char* hooks_end = hooks_section + section_size; + for (char* p = hooks_section; p < hooks_end; p += sizeof(hook::HookEntry)) { + auto* H = reinterpret_cast(p); + // Get corresponding function + // std::string fn_name = "_" + std::string(H->hookName); + std::string fn_name = std::string(H->name); + auto* proc_address = reinterpret_cast( + windows_utils::get_proc_address(fn_name, dll_handle)); + if (proc_address == nullptr) { + throw std::runtime_error( + fmt::format("couldn't find hook function: {}", fn_name)); + } + // Patch target code + create_hook(*H, proc_address); + } +} + +void RA2YRCPP::create_all_hooks() { + auto P = process::get_current_process(); + void* dll = windows_utils::find_dll(cfg::DLL_NAME); + if (dll == nullptr) { + throw std::runtime_error("ra2yrcpp main DLL not loaded"); + } + + // Get syringe section + auto section = windows_utils::find_section(dll, ".syhks00"); + if (section.data == nullptr) { + throw std::runtime_error(".syhks00 section not found from DLL"); + } + + create_all_hooks(reinterpret_cast(section.data), section.length, dll); +} + +RA2YRCPP* RA2YRCPP::get() { + static RA2YRCPP* I = nullptr; + if (I == nullptr) { + I = new RA2YRCPP(); + } + return I; +} + +void RA2YRCPP::start_service() { + if (service_ == nullptr) { + service_ = is_context::make_is(o); + } +} diff --git a/src/is_context.hpp b/src/is_context.hpp index 2fbe104..e146429 100644 --- a/src/is_context.hpp +++ b/src/is_context.hpp @@ -5,9 +5,11 @@ #include +#include #include #include +#include #include namespace dll_inject { @@ -62,6 +64,20 @@ void inject_dll(unsigned pid, std::string path_dll, void* get_context(ra2yrcpp::InstrumentationService::Options O); +// Global ra2yrcpp instance +struct RA2YRCPP { + ra2yrcpp::InstrumentationService::Options o; + ra2yrcpp::InstrumentationService* service_; + std::map hooks_; + void create_all_hooks(); + void create_all_hooks(char* hooks_section, std::size_t section_size, + void* dll_handle); + void create_hook(hook::HookEntry h, hook::hook_fn f); + void start_service(); + // Get global instance + static RA2YRCPP* get(); +}; + const DLLLoader::Options default_options{ {0U, 0U}, cfg::DLL_NAME, cfg::INIT_NAME, cfg::MAX_CLIENTS, cfg::SERVER_PORT, false, false}; diff --git a/src/ra2/abi.cpp b/src/ra2/abi.cpp index fed4af2..5898d7c 100644 --- a/src/ra2/abi.cpp +++ b/src/ra2/abi.cpp @@ -18,11 +18,6 @@ Xbyak::CodeGenerator* ABIGameMD::find_codegen(u32 address) { codegen_store& ABIGameMD::code_generators() { return code_generators_; } -util::acquire_t -ABIGameMD::acquire_code_generators() { - return util::acquire(&code_generators_, &mut_code_generators_); -} - bool ABIGameMD::SelectObject(u32 address) { return ra2::abi::SelectObject::call(this, address); } diff --git a/src/ra2/abi.hpp b/src/ra2/abi.hpp index 3cb8dbf..e34286c 100644 --- a/src/ra2/abi.hpp +++ b/src/ra2/abi.hpp @@ -5,14 +5,12 @@ #include "utility/array_iterator.hpp" #include "utility/function_traits.hpp" #include "utility/serialize.hpp" -#include "utility/sync.hpp" #include #include #include -#include #include #include #include @@ -79,9 +77,6 @@ class ABIGameMD { codegen_store& code_generators(); - util::acquire_t - acquire_code_generators(); - template auto call(Args... args) { return T::call(this, args...); diff --git a/src/yrclient_dll.cpp b/src/yrclient_dll.cpp index 65f6a83..6dfc76f 100644 --- a/src/yrclient_dll.cpp +++ b/src/yrclient_dll.cpp @@ -10,6 +10,7 @@ static void* g_context = nullptr; +// TODO: Create placeholder class to get all env vars. void ra2yrcpp::initialize(unsigned int max_clients, unsigned int port, bool no_init_hooks) { static std::mutex g_lock; @@ -20,7 +21,13 @@ void ra2yrcpp::initialize(unsigned int max_clients, unsigned int port, {cfg::SERVER_ADDRESS, port, max_clients, (h != nullptr ? h : cfg::ALLOWED_HOSTS_REGEX)}, no_init_hooks}; - g_context = is_context::get_context(O); + if (!O.no_init_hooks) { + auto* I = is_context::RA2YRCPP::get(); + I->o = O; + I->create_all_hooks(); + } else { + g_context = is_context::get_context(O); + } } g_lock.unlock(); diff --git a/tests/test_hooks.cpp b/tests/test_hooks.cpp index 6c79eba..b1dcf4b 100644 --- a/tests/test_hooks.cpp +++ b/tests/test_hooks.cpp @@ -15,6 +15,8 @@ using namespace hook; using namespace std; +// FIXME: rewrite tests +#if 0 /// Multiplies two unsigned integers, and returns result in EAX struct ExampleProgram : Xbyak::CodeGenerator { ExampleProgram() { @@ -165,6 +167,7 @@ TEST(HookTest, BasicCallbackMultipleThreads) { t.join(); } } +#endif TEST(HookTest, CorrectBehaviorWhenThreadsInHook) {} diff --git a/tests/test_instrumentation_service.cpp b/tests/test_instrumentation_service.cpp index 8837331..25c0b87 100644 --- a/tests/test_instrumentation_service.cpp +++ b/tests/test_instrumentation_service.cpp @@ -92,42 +92,6 @@ class IServiceTest : public InstrumentationServiceTest { void init() override {} }; -TEST_F(IServiceTest, HookingGetSetWorks) { -#ifdef RA2YRCPP_64 - GTEST_SKIP(); -#endif - // store initial flag value - std::string key = "test_key"; - std::string flag1 = "0xdeadbeef"; - std::string flag2 = "0xbeefdead"; - auto value_eq = [&](std::string v) { - auto r = cs->run(GetValue::create({key, ""})); - ASSERT_EQ(r.value(), v); - }; - - auto h = HookableCommand::create({}); - (void)cs->run(StoreValue::create({key, flag1})); - auto res0 = cs->run(h); - ASSERT_NE(res0.address_test_function(), 0); - value_eq(flag1); - { - ra2yrproto::HookEntry E; - E.set_address(res0.address_test_function()); - E.set_name("test_hook"); - E.set_code_length(res0.code_size()); - std::vector V; - V.push_back(E); - auto res_ih_a = cs->run(CreateHooks::create({true, V})); - value_eq(flag1); - } - // install callback, which modifies the value (TODO: jit the callback) - auto ac = AddCallback::create( - {res0.address_test_function(), res0.address_test_callback()}); - (void)cs->run(ac); - (void)cs->run(h); - value_eq(flag2); -} - class NewCommandsTest : public InstrumentationServiceTest { protected: void init() override { diff --git a/tests/util_proto.hpp b/tests/util_proto.hpp index 2ec6844..1e36fff 100644 --- a/tests/util_proto.hpp +++ b/tests/util_proto.hpp @@ -4,7 +4,6 @@ #include "ra2yrproto/core.pb.h" #include -#include namespace ra2yrcpp::test_util { @@ -32,47 +31,4 @@ struct GetValue { } }; -struct HookableCommand { - u64 address_test_function; - u32 code_size; - u64 address_test_callback; - - static ra2yrproto::commands::HookableCommand create(HookableCommand c) { - ra2yrproto::commands::HookableCommand s; - s.set_address_test_callback(c.address_test_callback); - s.set_address_test_function(c.address_test_function); - s.set_code_size(c.code_size); - return s; - } -}; - -struct AddCallback { - u64 hook_address; - u64 callback_address; - - static ra2yrproto::commands::AddCallback create(AddCallback c) { - ra2yrproto::commands::AddCallback s; - s.set_hook_address(c.hook_address); - s.set_callback_address(c.callback_address); - return s; - } -}; - -struct CreateHooks { - bool no_suspend_threads; - std::vector hooks; - - static ra2yrproto::commands::CreateHooks create(CreateHooks c) { - ra2yrproto::commands::CreateHooks s; - s.set_no_suspend_threads(c.no_suspend_threads); - for (auto& ch : c.hooks) { - auto* h = s.add_hooks(); - h->set_address(ch.address()); - h->set_name(ch.name()); - h->set_code_length(ch.code_length()); - } - return s; - } -}; - }; // namespace ra2yrcpp::test_util