diff --git a/CMakeLists.txt b/CMakeLists.txt index c9516e521..b00dbe7bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,11 @@ if(NOT CMAKE_BUILD_TYPE) ) endif() + +# CJK +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8") +endif() # Language specs set(CMAKE_CXX_STANDARD 20) set(CMAKE_C_STANDARD 17) diff --git a/primedev/Northstar.cmake b/primedev/Northstar.cmake index 35383e69c..b849a0222 100644 --- a/primedev/Northstar.cmake +++ b/primedev/Northstar.cmake @@ -1,13 +1,21 @@ # NorthstarDLL - +set(OPENSSL_USE_STATIC_LIBS TRUE) +set(ENV{OPENSSL_ROOT_DIR} "C:/Program Files/OpenSSL-Win64") +set(OPENSSL_MSVC_STATIC_RT TRUE) find_package(minhook REQUIRED) find_package(libcurl REQUIRED) find_package(minizip REQUIRED) find_package(silver-bun REQUIRED) +find_package(nlohmann-json REQUIRED) +find_package(OpenSSL REQUIRED) +find_package(httplib COMPONENTS OpenSSL) add_library( NorthstarDLL SHARED "resources.rc" + "core/anticheat.cpp" + "core/anticheat.h" + "client/audio.cpp" "client/audio.h" "client/chatcommand.cpp" @@ -79,6 +87,7 @@ add_library( "logging/sourceconsole.h" "masterserver/masterserver.cpp" "masterserver/masterserver.h" + "masterserver/cabundle.h" "mods/autodownload/moddownloader.h" "mods/autodownload/moddownloader.cpp" "mods/compiled/kb_act.cpp" @@ -99,6 +108,7 @@ add_library( "plugins/plugins.h" "plugins/pluginmanager.h" "plugins/pluginmanager.cpp" + "scripts/clantag.cpp" "scripts/client/clientchathooks.cpp" "scripts/client/cursorposition.cpp" "scripts/client/scriptbrowserhooks.cpp" @@ -107,6 +117,16 @@ add_library( "scripts/client/scriptoriginauth.cpp" "scripts/client/scriptserverbrowser.cpp" "scripts/client/scriptservertoclientstringcommand.cpp" + "scripts/server/scriptuserinfo.cpp" + "scripts/scriptmasterservermessages.cpp" + "scripts/scriptmasterservermessages.h" + "scripts/scriptgamestate.cpp" + "scripts/scriptsvm.cpp" + + "scripts/scriptgamestate.h" + "scripts/scriptgamestate.cpp" + "scripts/scriptmatchmakingevents.h" + "scripts/scriptmatchmakingevents.cpp" "scripts/server/miscserverfixes.cpp" "scripts/server/miscserverscript.cpp" "scripts/server/scriptuserinfo.cpp" @@ -120,6 +140,8 @@ add_library( "server/auth/bansystem.h" "server/auth/serverauthentication.cpp" "server/auth/serverauthentication.h" + "server/svm.cpp" + "server/svm.h" "server/alltalk.cpp" "server/ai_helper.cpp" "server/ai_helper.h" @@ -130,7 +152,7 @@ add_library( "server/r2server.h" "server/serverchathooks.cpp" "server/serverchathooks.h" - "server/servernethooks.cpp" + "server/serverpresence.cpp" "server/serverpresence.h" "shared/exploit_fixes/exploitfixes.cpp" @@ -161,6 +183,10 @@ add_library( "util/version.h" "util/wininfo.cpp" "util/wininfo.h" + "util/base64.cpp" + "util/base64.h" + "util/dohworker.cpp" + "util/dohworker.h" "vscript/languages/squirrel_re/include/squirrel.h" "vscript/languages/squirrel_re/squirrel/sqarray.h" "vscript/languages/squirrel_re/squirrel/sqclosure.h" @@ -187,7 +213,11 @@ add_library( target_link_libraries( NorthstarDLL - PRIVATE minhook + PRIVATE nlohmann_json::nlohmann_json + OpenSSL::SSL + OpenSSL::Crypto + httplib::httplib + minhook libcurl minizip silver-bun @@ -212,6 +242,7 @@ target_compile_definitions( PRIVATE UNICODE _UNICODE CURL_STATICLIB + ) set_target_properties( diff --git a/primedev/RCa23572 b/primedev/RCa23572 new file mode 100644 index 000000000..f237ac026 Binary files /dev/null and b/primedev/RCa23572 differ diff --git a/primedev/RCb23572 b/primedev/RCb23572 new file mode 100644 index 000000000..f237ac026 Binary files /dev/null and b/primedev/RCb23572 differ diff --git a/primedev/cmake/Findnlohmann-json.cmake b/primedev/cmake/Findnlohmann-json.cmake new file mode 100644 index 000000000..bd0079f94 --- /dev/null +++ b/primedev/cmake/Findnlohmann-json.cmake @@ -0,0 +1,6 @@ +if(NOT nlohmann-json_FOUND) + check_init_submodule(${PROJECT_SOURCE_DIR}/primedev/thirdparty/nlohmann-json) + + add_subdirectory(${PROJECT_SOURCE_DIR}/primedev/thirdparty/nlohmann-json nlohmann-json) + set(nlohmann-json_FOUND 1) +endif() diff --git a/primedev/core/anticheat.cpp b/primedev/core/anticheat.cpp new file mode 100644 index 000000000..ffe2a48ae --- /dev/null +++ b/primedev/core/anticheat.cpp @@ -0,0 +1,50 @@ +#include "pch.h" +#include "hooks.h" +#include +#include "anticheat.h" +#include +#include +#include +#include +#include +#include +#include +#include + +AUTOHOOK_INIT() + +TempReadWrite::TempReadWrite(void* ptr) +{ + m_ptr = ptr; + MEMORY_BASIC_INFORMATION mbi; + VirtualQuery(m_ptr, &mbi, sizeof(mbi)); + VirtualProtect(mbi.BaseAddress, mbi.RegionSize, PAGE_EXECUTE_READWRITE, &mbi.Protect); + m_origProtection = mbi.Protect; +} + +TempReadWrite::~TempReadWrite() +{ + MEMORY_BASIC_INFORMATION mbi; + VirtualQuery(m_ptr, &mbi, sizeof(mbi)); + VirtualProtect(mbi.BaseAddress, mbi.RegionSize, m_origProtection, &mbi.Protect); +} + +void ClientAnticheatSystem::NoFindWindowHack(uintptr_t baseAddress) +{ + unsigned seed = time(0); + srand(seed); + char ObfChar[3]; + int ObfuscateNum = 100 + rand() % 899; + sprintf(ObfChar, "%d", ObfuscateNum); + std::cout << ObfuscateNum << std::endl; + char* ptr = ((char*)baseAddress + 0x607BD0); + TempReadWrite rw(ptr); + *(ptr + 14) = (char)ObfChar[0]; + *(ptr + 16) = (char)ObfChar[1]; + *(ptr + 18) = (char)ObfChar[2]; +} + +ON_DLL_LOAD("engine.dll", ACinit, (CModule module)) +{ + g_ClientAnticheatSystem.NoFindWindowHack(module.GetModuleBase()); +} diff --git a/primedev/core/anticheat.h b/primedev/core/anticheat.h new file mode 100644 index 000000000..51cd11b72 --- /dev/null +++ b/primedev/core/anticheat.h @@ -0,0 +1,23 @@ +#pragma once +#include +#include + +typedef unsigned long DWORD; + +class TempReadWrite +{ + private: + DWORD m_origProtection; + void* m_ptr; + + public: + TempReadWrite(void* ptr); + ~TempReadWrite(); +}; +class ClientAnticheatSystem +{ + public: + void NoFindWindowHack(uintptr_t baseAddress); +}; + +extern ClientAnticheatSystem g_ClientAnticheatSystem; diff --git a/primedev/engine/hoststate.cpp b/primedev/engine/hoststate.cpp index 4a4d909da..1ca324f4f 100644 --- a/primedev/engine/hoststate.cpp +++ b/primedev/engine/hoststate.cpp @@ -75,6 +75,8 @@ static void __fastcall h_CHostState__State_NewGame(CHostState* self) g_pServerPresence->SetPlaylist(R2::GetCurrentPlaylistName()); g_pServerPresence->SetPort(Cvar_hostport->GetInt()); + g_pServerAuthentication->StartPlayerAuthServer(); + g_pServerAuthentication->m_bNeedLocalAuthForNewgame = false; } diff --git a/primedev/masterserver/cabundle.h b/primedev/masterserver/cabundle.h new file mode 100644 index 000000000..c9bfd00f5 --- /dev/null +++ b/primedev/masterserver/cabundle.h @@ -0,0 +1,32 @@ +#pragma once +static const char cabundle[] = "-----BEGIN CERTIFICATE-----\n" + "MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw\n" + "TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh\n" + "cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4\n" + "WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu\n" + "ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY\n" + "MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc\n" + "h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+\n" + "0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U\n" + "A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW\n" + "T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH\n" + "B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC\n" + "B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv\n" + "KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn\n" + "OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn\n" + "jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw\n" + "qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI\n" + "rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV\n" + "HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq\n" + "hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL\n" + "ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ\n" + "3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK\n" + "NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5\n" + "ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur\n" + "TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC\n" + "jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc\n" + "oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq\n" + "4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA\n" + "mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d\n" + "emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=\n" + "-----END CERTIFICATE-----\n"; diff --git a/primedev/masterserver/masterserver.cpp b/primedev/masterserver/masterserver.cpp index 56ab608d2..24b99356c 100644 --- a/primedev/masterserver/masterserver.cpp +++ b/primedev/masterserver/masterserver.cpp @@ -5,27 +5,39 @@ #include "core/tier0.h" #include "core/vanilla.h" #include "engine/r2engine.h" +#include "client/r2client.h" #include "mods/modmanager.h" #include "shared/misccommands.h" #include "util/utils.h" #include "util/version.h" #include "server/auth/bansystem.h" #include "dedicated/dedicated.h" +#include "core/anticheat.h" +#include "util/dohworker.h" +#include "util/base64.h" +#include "scripts/scriptgamestate.h" #include "rapidjson/document.h" #include "rapidjson/stringbuffer.h" #include "rapidjson/writer.h" #include "rapidjson/error/en.h" +#include "cpp-httplib/httplib.h" +#include "curl/curl.h" +#include "cabundle.h" +#include "nlohmann/json.hpp" #include #include +#include using namespace std::chrono_literals; - MasterServerManager* g_pMasterServerManager; +ClientAnticheatSystem g_ClientAnticheatSystem; +ConVar* Cvar_ns_matchmaker_hostname; ConVar* Cvar_ns_masterserver_hostname; ConVar* Cvar_ns_curl_log_enable; +ConVar* Cvar_ns_server_reg_token; RemoteServerInfo::RemoteServerInfo( const char* newId, @@ -33,7 +45,7 @@ RemoteServerInfo::RemoteServerInfo( const char* newDescription, const char* newMap, const char* newPlaylist, - const char* newRegion, + int newGameState, int newPlayerCount, int newMaxPlayers, bool newRequiresPassword) @@ -49,29 +61,119 @@ RemoteServerInfo::RemoteServerInfo( strncpy_s((char*)map, sizeof(map), newMap, sizeof(map) - 1); strncpy_s((char*)playlist, sizeof(playlist), newPlaylist, sizeof(playlist) - 1); - strncpy((char*)region, newRegion, sizeof(region)); - region[sizeof(region) - 1] = 0; + gameState = newGameState; playerCount = newPlayerCount; maxPlayers = newMaxPlayers; } +inline std::string encode_query_param(const std::string& value) +{ + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (const auto c : value) + { + if (std::isalnum(static_cast(c)) || c == '-' || c == '.' || c == '_' || c == '~') + { + escaped << c; + } + else + { + escaped << std::uppercase; + escaped << '%' << std::setw(2) << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} void SetCommonHttpClientOptions(CURL* curl) { curl_easy_setopt(curl, CURLOPT_IPRESOLVE, CURL_IPRESOLVE_V4); curl_easy_setopt(curl, CURLOPT_VERBOSE, Cvar_ns_curl_log_enable->GetBool()); curl_easy_setopt(curl, CURLOPT_USERAGENT, &NSUserAgent); - // Timeout since the MS has fucky async functions without await, making curl hang due to a successful connection but no response for ~90 - // seconds. - curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); - // curl_easy_setopt(curl, CURLOPT_STDERR, stdout); - if (CommandLine()->FindParm("-msinsecure")) // TODO: this check doesn't seem to work + if (!strstr(GetCommandLineA(), "-disabledoh")) { - curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L); - curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); + std::string masterserver_hostname = Cvar_ns_masterserver_hostname->GetString(); + const std::string doh_result = g_DohWorker->GetDOHResolve(masterserver_hostname); + if (g_DohWorker->m_bDohAvailable) + { + struct curl_slist* host = nullptr; + masterserver_hostname.erase(0, 8); + const std::string addr = masterserver_hostname + ":443:" + doh_result; + // spdlog::info(addr); + host = curl_slist_append(nullptr, addr.c_str()); + curl_easy_setopt(curl, CURLOPT_RESOLVE, host); + } + else + { + // spdlog::warn("[DOH] service is not available. falling back to DNS"); + } + } + else + { + // spdlog::warn("[DOH] service disabled"); } + // curl_easy_setopt(curl, CURLOPT_STDERR, stdout); + + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); } +httplib::Client SetupHttpClient() +{ + std::string ms_addr = Cvar_ns_masterserver_hostname->GetString(); + httplib::Client cli(ms_addr); + //, {"Accept-Encoding", "gzip"}, {"Content-Encoding", "gzip"} + cli.set_default_headers({{"User-Agent", NSUserAgent}}); + cli.set_compress(true); + cli.set_decompress(true); + cli.set_read_timeout(10, 0); + cli.set_write_timeout(10, 0); + // cli.enable_server_certificate_verification(false); + + cli.load_ca_cert_store(cabundle, sizeof(cabundle)); + // cli.set_ca_cert_path("ca-bundle.crt"); + if (!strstr(GetCommandLineA(), "-disabledoh")) + { + std::string doh_result = g_DohWorker->GetDOHResolve(ms_addr); + if (g_DohWorker->m_bDohAvailable) + { + cli.set_hostname_addr_map({{ms_addr, doh_result}}); + } + else + { + // spdlog::warn("[DOH] service is not available. falling back to DNS"); + } + } + return cli; +} +httplib::Client SetupMatchmakerHttpClient() +{ + std::string ms_addr = Cvar_ns_matchmaker_hostname->GetString(); + httplib::Client cli(ms_addr); + //, {"Accept-Encoding", "gzip"}, {"Content-Encoding", "gzip"} + cli.set_default_headers({{"User-Agent", NSUserAgent}}); + cli.set_compress(true); + cli.set_decompress(true); + cli.set_read_timeout(10, 0); + cli.set_write_timeout(10, 0); + if (!strstr(GetCommandLineA(), "-disabledoh")) + { + std::string doh_result = g_DohWorker->GetDOHResolve(ms_addr); + if (g_DohWorker->m_bDohAvailable) + { + cli.set_hostname_addr_map({{ms_addr, doh_result}}); + } + else + { + // spdlog::warn("[DOH] service is not available. falling back to DNS"); + } + } + return cli; +} void MasterServerManager::ClearServerList() { // this doesn't really do anything lol, probably isn't threadsafe @@ -84,9 +186,249 @@ void MasterServerManager::ClearServerList() size_t CurlWriteToStringBufferCallback(char* contents, size_t size, size_t nmemb, void* userp) { - ((std::string*)userp)->append((char*)contents, size * nmemb); + static_cast(userp)->append((char*)contents, size * nmemb); return size * nmemb; } +bool MasterServerManager::StartMatchmaking(MatchmakeInfo* status) +{ + // no need for a request thread here cuz we always call this function in a new thread. + // ps,aitdm,ttdm + CURL* curl = curl_easy_init(); + SetCommonHttpClientOptions(curl); + std::string read_buffer; + const std::string token = m_sOwnClientAuthToken; + const std::string local_uid = g_pLocalPlayerUserID; + char* local_uid_escaped = curl_easy_escape(curl, local_uid.c_str(), local_uid.length()); + char* token_escaped = curl_easy_escape(curl, token.c_str(), token.length()); + std::string query_fmt_str = "{}/join?id={}&token={}&aa_enabled={}"; + for (int i = 0; i < status->playlistList.size(); i++) + { + query_fmt_str.append(fmt::format("&playlist={}", status->playlistList[i])); + } + std::string query = + fmt::format((query_fmt_str), Cvar_ns_matchmaker_hostname->GetString(), local_uid_escaped, token_escaped, "true") + .c_str(); // TODO: add working AA selection + // spdlog::warn("{}", query); + // return false; + curl_easy_setopt(curl, CURLOPT_URL, query.c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &read_buffer); + + const CURLcode result = curl_easy_perform(curl); + // spdlog::info("[Matchmaker] JOIN: Result:{},buffer:{}", result, read_buffer.c_str()); + if (result == CURLcode::CURLE_OK) + { + try + { + nlohmann::json resjson = nlohmann::json::parse(read_buffer); + if (resjson.at("success") == true) + { + // success + curl_easy_cleanup(curl); + return true; + } + else + { + // fucked + goto REQUEST_END_CLEANUP; + } + } + catch (nlohmann::json::parse_error& e) + { + spdlog::error("Failed communicating with matchmaker: encountered parse error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + catch (nlohmann::json::out_of_range& e) + { + spdlog::error("Failed communicating with matchmaker: encountered data error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + } + // we goto this instead of returning so we always hit this +REQUEST_END_CLEANUP: + + curl_easy_cleanup(curl); + return false; +} +bool MasterServerManager::CancelMatchmaking() +{ + // no need for a request thread here cuz we always call this function in a new thread. + // ps,aitdm,ttdm + CURL* curl = curl_easy_init(); + SetCommonHttpClientOptions(curl); + std::string read_buffer; + const std::string token = m_sOwnClientAuthToken; + char* token_escaped = curl_easy_escape(curl, token.c_str(), token.length()); + const std::string local_uid = g_pLocalPlayerUserID; + char* local_uid_escaped = curl_easy_escape(curl, local_uid.c_str(), local_uid.length()); + curl_easy_setopt( + curl, + CURLOPT_URL, + fmt::format("{}/quit?id={}&token={}", Cvar_ns_matchmaker_hostname->GetString(), local_uid_escaped, token_escaped).c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &read_buffer); + + const CURLcode result = curl_easy_perform(curl); + + if (result == CURLcode::CURLE_OK) + { + // spdlog::info("[Matchmaker] Result:{},buffer:{}", result, read_buffer.c_str()); + try + { + nlohmann::json resjson = nlohmann::json::parse(read_buffer); + if (resjson.at("success") == true) + { + // success + curl_easy_cleanup(curl); + return true; + } + else + { + // fucked + goto REQUEST_END_CLEANUP; + } + } + catch (nlohmann::json::parse_error& e) + { + spdlog::error("Failed communicating with matchmaker: encountered parse error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + catch (nlohmann::json::out_of_range& e) + { + spdlog::error("Failed communicating with matchmaker: encountered data error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + + // spdlog::error("Failed reading player clantag"); + } + + // we goto this instead of returning so we always hit this +REQUEST_END_CLEANUP: + + curl_easy_cleanup(curl); + return false; +} + +bool MasterServerManager::UpdateMatchmakingStatus(MatchmakeInfo* status) +{ + // no need for a request thread here cuz we always call this function in a new thread. + // ps,aitdm,ttdm + CURL* curl = curl_easy_init(); + SetCommonHttpClientOptions(curl); + std::string read_buffer; + const std::string token = m_sOwnClientAuthToken; + const std::string local_uid = g_pLocalPlayerUserID; + char* local_uid_escaped = curl_easy_escape(curl, local_uid.c_str(), local_uid.length()); + char* token_escaped = curl_easy_escape(curl, token.c_str(), token.length()); + curl_easy_setopt( + curl, + CURLOPT_URL, + fmt::format("{}/state?id={}&token={}", Cvar_ns_matchmaker_hostname->GetString(), local_uid_escaped, token_escaped).c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "GET"); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &read_buffer); + + const CURLcode result = curl_easy_perform(curl); + + if (result == CURLcode::CURLE_OK) + { + // spdlog::info("[Matchmaker] STATE: Result:{},buffer:{}", result, read_buffer.c_str()); + try + { + nlohmann::json server_response = nlohmann::json::parse(read_buffer); + if (server_response.at("success") == true) + { + std::string state_type = server_response.at("state"); + + if (!strcmp(state_type.c_str(), "#MATCHMAKING_QUEUED")) + { + status->etaSeconds = ""; + status->status = state_type; + // spdlog::info("[Matchmaker] MATCHMAKING_QUEUED"); + curl_easy_cleanup(curl); + return true; + } + if (!strcmp(state_type.c_str(), "#MATCHMAKING_ALLOCATING_SERVER")) + { + // spdlog::info("[Matchmaker] MATCHMAKING_ALLOCATING_SERVER"); + status->status = state_type; + curl_easy_cleanup(curl); + return true; + } + if (!strcmp(state_type.c_str(), "#MATCHMAKING_MATCH_CONNECTING")) + { + status->status = state_type; + status->serverId = server_response.at("id"); + status->serverReady = true; + curl_easy_cleanup(curl); + return true; + } + } + else + { + goto REQUEST_END_CLEANUP; + } + } + catch (nlohmann::json::parse_error& e) + { + spdlog::error("Failed communicating with matchmaker: encountered parse error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + catch (nlohmann::json::out_of_range& e) + { + spdlog::error("Failed communicating with matchmaker: encountered data error \"{}\"", e.what()); + goto REQUEST_END_CLEANUP; + } + + spdlog::error("Failed reading matchmaking status"); + } + + // we goto this instead of returning so we always hit this +REQUEST_END_CLEANUP: + + curl_easy_cleanup(curl); + return false; +} + +bool MasterServerManager::SetLocalPlayerClanTag(const std::string clantag) +{ + + httplib::Client cli = SetupHttpClient(); + const std::string querystring = fmt::format( + "/client/clantag?clantag={}&id={}&token={}", + encode_query_param(clantag), + encode_query_param(g_pLocalPlayerUserID), + encode_query_param(m_sOwnClientAuthToken)); + auto res = cli.Post(querystring); + if (res && res->status == 200) + { + try + { + nlohmann::json setclantag_json = nlohmann::json::parse(res->body); + if (setclantag_json.at("success")) + { + return true; + } + else + { + return false; + } + } + catch (nlohmann::json::parse_error& e) + { + spdlog::error("Failed setting local player clantag: \"{}\"", e.what()); + } + catch (nlohmann::json::out_of_range& e) + { + spdlog::error("Failed setting local player clantag: \"{}\"", e.what()); + } + + return false; + } + return false; +} void MasterServerManager::AuthenticateOriginWithMasterServer(const char* uid, const char* originToken) { @@ -95,96 +437,60 @@ void MasterServerManager::AuthenticateOriginWithMasterServer(const char* uid, co // do this here so it's instantly set m_bOriginAuthWithMasterServerInProgress = true; - std::string uidStr(uid); - std::string tokenStr(originToken); - - m_bOriginAuthWithMasterServerSuccessful = false; - m_sOriginAuthWithMasterServerErrorCode = ""; - m_sOriginAuthWithMasterServerErrorMessage = ""; + std::string uid_str(uid); + std::string token_str(originToken); - std::thread requestThread( - [this, uidStr, tokenStr]() + std::thread request_thread( + [this, uid_str, token_str]() { - spdlog::info("Trying to authenticate with northstar masterserver for user {}", uidStr); - - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - std::string readBuffer; - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format("{}/client/origin_auth?id={}&token={}", Cvar_ns_masterserver_hostname->GetString(), uidStr, tokenStr).c_str()); - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "GET"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - CURLcode result = curl_easy_perform(curl); - ScopeGuard cleanup( - [&] - { - m_bOriginAuthWithMasterServerInProgress = false; - m_bOriginAuthWithMasterServerDone = true; - curl_easy_cleanup(curl); - }); + spdlog::info("Trying to authenticate with northstar masterserver for user {}", uid_str); - if (result == CURLcode::CURLE_OK) + httplib::Client cli = SetupHttpClient(); + const std::string endpoint = + fmt::format("/client/origin_auth?id={}&token={}", encode_query_param(uid_str), encode_query_param(token_str)); + if (auto res = cli.Get(endpoint)) { m_bSuccessfullyConnected = true; - - rapidjson_document originAuthInfo; - originAuthInfo.Parse(readBuffer.c_str()); - - if (originAuthInfo.HasParseError()) + try { - spdlog::error( - "Failed reading origin auth info response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(originAuthInfo.GetParseError())); - return; - } + nlohmann::json originAuthInfo = nlohmann::json::parse(res->body); + if (originAuthInfo["success"]) + { + m_sOwnClientAuthToken = originAuthInfo.at("token"); + m_bOriginAuthWithMasterServerSuccess = true; - if (!originAuthInfo.IsObject() || !originAuthInfo.HasMember("success")) - { - spdlog::error("Failed reading origin auth info response: malformed response object {}", readBuffer); - return; + spdlog::info("Northstar origin authentication completed successfully!"); + } + else + { + m_sAuthFailureReason = originAuthInfo.at("error").at("enum"); + m_sAuthFailureMessage = originAuthInfo.at("error").at("msg"); + spdlog::error("Northstar origin authentication failed: {}", m_sAuthFailureMessage); + } } - - if (originAuthInfo["success"].IsTrue() && originAuthInfo.HasMember("token") && originAuthInfo["token"].IsString()) + catch (nlohmann::json::parse_error& e) { - strncpy_s( - m_sOwnClientAuthToken, - sizeof(m_sOwnClientAuthToken), - originAuthInfo["token"].GetString(), - sizeof(m_sOwnClientAuthToken) - 1); - spdlog::info("Northstar origin authentication completed successfully!"); - m_bOriginAuthWithMasterServerSuccessful = true; + spdlog::error("Failed reading origin auth info response: encountered parse error \"{}\"", e.what()); } - else + catch (nlohmann::json::out_of_range& e) { - spdlog::error("Northstar origin authentication failed"); - - if (originAuthInfo.HasMember("error") && originAuthInfo["error"].IsObject()) - { - - if (originAuthInfo["error"].HasMember("enum") && originAuthInfo["error"]["enum"].IsString()) - { - m_sOriginAuthWithMasterServerErrorCode = originAuthInfo["error"]["enum"].GetString(); - } - - if (originAuthInfo["error"].HasMember("msg") && originAuthInfo["error"]["msg"].IsString()) - { - m_sOriginAuthWithMasterServerErrorMessage = originAuthInfo["error"]["msg"].GetString(); - } - } + spdlog::error("Failed reading origin auth info response: encountered data error \"{}\"", e.what()); } } else { - spdlog::error("Failed performing northstar origin auth: error {}", curl_easy_strerror(result)); + const auto err = res.error(); + m_sAuthFailureReason = std::string("ERROR_NO_CONNECTION"); + m_sAuthFailureMessage = fmt::format("与主服务器进行初始化通信时出现错误:{}", httplib::to_string(err)); + spdlog::error("Failed performing northstar origin auth: {}", httplib::to_string(err)); m_bSuccessfullyConnected = false; } + + m_bOriginAuthWithMasterServerInProgress = false; + m_bOriginAuthWithMasterServerDone = true; }); - requestThread.detach(); + request_thread.detach(); } void MasterServerManager::RequestServerList() @@ -192,7 +498,7 @@ void MasterServerManager::RequestServerList() // do this here so it's instantly set on call for scripts m_bScriptRequestingServerList = true; - std::thread requestThread( + std::thread request_thread( [this]() { // make sure we never have 2 threads writing at once @@ -203,825 +509,429 @@ void MasterServerManager::RequestServerList() m_bRequestingServerList = true; m_bScriptRequestingServerList = true; - spdlog::info("Requesting server list from {}", Cvar_ns_masterserver_hostname->GetString()); + httplib::Client cli = SetupHttpClient(); - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt(curl, CURLOPT_URL, fmt::format("{}/client/servers", Cvar_ns_masterserver_hostname->GetString()).c_str()); - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "GET"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - CURLcode result = curl_easy_perform(curl); - ScopeGuard cleanup( - [&] - { - m_bRequestingServerList = false; - m_bScriptRequestingServerList = false; - curl_easy_cleanup(curl); - }); - - if (result == CURLcode::CURLE_OK) + if (auto res = cli.Get("/client/servers")) { m_bSuccessfullyConnected = true; - - rapidjson_document serverInfoJson; - serverInfoJson.Parse(readBuffer.c_str()); - - if (serverInfoJson.HasParseError()) + try { - spdlog::error( - "Failed reading masterserver response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(serverInfoJson.GetParseError())); - return; - } + nlohmann::json server_list_json = nlohmann::json::parse(res->body); - if (serverInfoJson.IsObject() && serverInfoJson.HasMember("error")) - { - spdlog::error("Failed reading masterserver response: got fastify error response"); - spdlog::error(readBuffer); - return; - } + - if (!serverInfoJson.IsArray()) - { - spdlog::error("Failed reading masterserver response: root object is not an array"); - return; - } - - rapidjson::GenericArray serverArray = serverInfoJson.GetArray(); - - spdlog::info("Got {} servers", serverArray.Size()); - - for (auto& serverObj : serverArray) - { - if (!serverObj.IsObject()) + for (auto& server_obj : server_list_json) { - spdlog::error("Failed reading masterserver response: member of server array is not an object"); - return; - } + + char id[33]; + strncpy_s(id, 33, server_obj.at("id").get().c_str(), 33); - // todo: verify json props are fine before adding to m_remoteServers - if (!serverObj.HasMember("id") || !serverObj["id"].IsString() || !serverObj.HasMember("name") || - !serverObj["name"].IsString() || !serverObj.HasMember("description") || !serverObj["description"].IsString() || - !serverObj.HasMember("map") || !serverObj["map"].IsString() || !serverObj.HasMember("playlist") || - !serverObj["playlist"].IsString() || !serverObj.HasMember("playerCount") || !serverObj["playerCount"].IsNumber() || - !serverObj.HasMember("maxPlayers") || !serverObj["maxPlayers"].IsNumber() || !serverObj.HasMember("hasPassword") || - !serverObj["hasPassword"].IsBool() || !serverObj.HasMember("modInfo") || !serverObj["modInfo"].HasMember("Mods") || - !serverObj["modInfo"]["Mods"].IsArray()) - { - spdlog::error("Failed reading masterserver response: malformed server object"); - continue; - }; - const char* id = serverObj["id"].GetString(); + RemoteServerInfo* newServer = nullptr; + bool createNewServerInfo = true; + for (RemoteServerInfo& server : m_vRemoteServers) + { + if (!strncmp((const char*)server.id, id, 32)) + { + server = RemoteServerInfo( + id, + server_obj.at("name").get().c_str(), + server_obj.at("description").get().c_str(), + server_obj.at("map").get().c_str(), + server_obj.at("playlist").get().c_str(), + server_obj.at("gameState").get(), + server_obj.at("playerCount").get(), + server_obj.at("maxPlayers").get(), + server_obj.at("hasPassword").get()); + newServer = &server; + createNewServerInfo = false; + break; + } - RemoteServerInfo* newServer = nullptr; + } - bool createNewServerInfo = true; - for (RemoteServerInfo& server : m_vRemoteServers) - { - // if server already exists, update info rather than adding to it - if (!strncmp((const char*)server.id, id, 32)) + if (createNewServerInfo) { - server = RemoteServerInfo( + + newServer = &m_vRemoteServers.emplace_back( id, - serverObj["name"].GetString(), - serverObj["description"].GetString(), - serverObj["map"].GetString(), - serverObj["playlist"].GetString(), - (serverObj.HasMember("region") && serverObj["region"].IsString()) ? serverObj["region"].GetString() : "", - serverObj["playerCount"].GetInt(), - serverObj["maxPlayers"].GetInt(), - serverObj["hasPassword"].IsTrue()); - newServer = &server; - createNewServerInfo = false; - break; + server_obj.at("name").get().c_str(), + server_obj.at("description").get().c_str(), + server_obj.at("map").get().c_str(), + server_obj.at("playlist").get().c_str(), + server_obj.at("gameState").get(), + server_obj.at("playerCount").get(), + server_obj.at("maxPlayers").get(), + server_obj.at("hasPassword").get()); } - } + newServer->requiredMods.clear(); - // server didn't exist - if (createNewServerInfo) - newServer = &m_vRemoteServers.emplace_back( - id, - serverObj["name"].GetString(), - serverObj["description"].GetString(), - serverObj["map"].GetString(), - serverObj["playlist"].GetString(), - (serverObj.HasMember("region") && serverObj["region"].IsString()) ? serverObj["region"].GetString() : "", - serverObj["playerCount"].GetInt(), - serverObj["maxPlayers"].GetInt(), - serverObj["hasPassword"].IsTrue()); - - newServer->requiredMods.clear(); - for (auto& requiredMod : serverObj["modInfo"]["Mods"].GetArray()) - { - RemoteModInfo modInfo; - - if (!requiredMod.HasMember("RequiredOnClient") || !requiredMod["RequiredOnClient"].IsTrue()) - continue; - - if (!requiredMod.HasMember("Name") || !requiredMod["Name"].IsString()) - continue; - modInfo.Name = requiredMod["Name"].GetString(); - if (!requiredMod.HasMember("Version") || !requiredMod["Version"].IsString()) - continue; - modInfo.Version = requiredMod["Version"].GetString(); + for (auto& mod : server_obj.at("modInfo").at("Mods")) + { + if (mod.at("RequiredOnClient")) + { + RemoteModInfo mod_info; + + mod_info.Name = mod.at("Name"); + mod_info.Version = mod.at("Version"); + newServer->requiredMods.push_back(mod_info); + } + } - newServer->requiredMods.push_back(modInfo); } - // Can probably re-enable this later with a -verbose flag, but slows down loading of the server browser quite a bit as - // is - // spdlog::info( - // "Server {} on map {} with playlist {} has {}/{} players", serverObj["name"].GetString(), - // serverObj["map"].GetString(), serverObj["playlist"].GetString(), serverObj["playerCount"].GetInt(), - // serverObj["maxPlayers"].GetInt()); - } - std::sort( - m_vRemoteServers.begin(), - m_vRemoteServers.end(), - [](RemoteServerInfo& a, RemoteServerInfo& b) { return a.playerCount > b.playerCount; }); + std::ranges::sort( + m_vRemoteServers.begin(), + m_vRemoteServers.end(), + [](const RemoteServerInfo& a, const RemoteServerInfo& b) { return a.playerCount > b.playerCount; }); + } + catch (nlohmann::json::parse_error& e) + { + spdlog::error("Failed reading origin auth info response: encountered parse error \"{}\"", e.what()); + } + catch (nlohmann::json::out_of_range& e) + { + spdlog::error("Failed reading origin auth info response: encountered data error \"{}\"", e.what()); + } } else { - spdlog::error("Failed requesting servers: error {}", curl_easy_strerror(result)); + auto err = res.error(); + spdlog::error("Failed requesting servers: error {}", httplib::to_string(err)); m_bSuccessfullyConnected = false; } + m_bRequestingServerList = false; + m_bScriptRequestingServerList = false; }); - requestThread.detach(); + request_thread.detach(); } void MasterServerManager::RequestMainMenuPromos() { m_bHasMainMenuPromoData = false; - std::thread requestThread( + std::thread request_thread( [this]() { while (m_bOriginAuthWithMasterServerInProgress || !m_bOriginAuthWithMasterServerDone) Sleep(500); - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt( - curl, CURLOPT_URL, fmt::format("{}/client/mainmenupromos", Cvar_ns_masterserver_hostname->GetString()).c_str()); - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "GET"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - CURLcode result = curl_easy_perform(curl); - ScopeGuard cleanup([&] { curl_easy_cleanup(curl); }); - - if (result == CURLcode::CURLE_OK) + httplib::Client cli = SetupHttpClient(); + if (auto res = cli.Get("/client/mainmenupromos")) { m_bSuccessfullyConnected = true; + try + { + nlohmann::json mainMenuPromoJson = nlohmann::json::parse(res->body); + m_sMainMenuPromoData.newInfoTitle1 = mainMenuPromoJson.at("newInfo").at("Title1"); + m_sMainMenuPromoData.newInfoTitle2 = mainMenuPromoJson.at("newInfo").at("Title2"); + m_sMainMenuPromoData.newInfoTitle3 = mainMenuPromoJson.at("newInfo").at("Title3"); - rapidjson_document mainMenuPromoJson; - mainMenuPromoJson.Parse(readBuffer.c_str()); + m_sMainMenuPromoData.largeButtonTitle = mainMenuPromoJson.at("largeButton").at("Title"); + m_sMainMenuPromoData.largeButtonText = mainMenuPromoJson.at("largeButton").at("Text"); + m_sMainMenuPromoData.largeButtonUrl = mainMenuPromoJson.at("largeButton").at("Url"); + m_sMainMenuPromoData.largeButtonImageIndex = mainMenuPromoJson.at("largeButton").at("ImageIndex"); - if (mainMenuPromoJson.HasParseError()) - { - spdlog::error( - "Failed reading masterserver main menu promos response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(mainMenuPromoJson.GetParseError())); - return; - } + m_sMainMenuPromoData.smallButton1Title = mainMenuPromoJson.at("smallButton1").at("Title"); + m_sMainMenuPromoData.smallButton1Url = mainMenuPromoJson.at("smallButton1").at("Url"); + m_sMainMenuPromoData.smallButton1ImageIndex = mainMenuPromoJson.at("smallButton1").at("ImageIndex"); - if (!mainMenuPromoJson.IsObject()) - { - spdlog::error("Failed reading masterserver main menu promos response: root object is not an object"); - return; - } + m_sMainMenuPromoData.smallButton2Title = mainMenuPromoJson.at("smallButton2").at("Title"); + m_sMainMenuPromoData.smallButton2Url = mainMenuPromoJson.at("smallButton2").at("Url"); + m_sMainMenuPromoData.smallButton2ImageIndex = mainMenuPromoJson.at("smallButton2").at("ImageIndex"); - if (mainMenuPromoJson.HasMember("error")) + m_bHasMainMenuPromoData = true; + } + catch (nlohmann::json::parse_error& e) { - spdlog::error("Failed reading masterserver response: got fastify error response"); - spdlog::error(readBuffer); - return; + spdlog::error("Failed reading masterserver main menu promos response: encountered parse error \"{}\"", e.what()); } - - if (!mainMenuPromoJson.HasMember("newInfo") || !mainMenuPromoJson["newInfo"].IsObject() || - !mainMenuPromoJson["newInfo"].HasMember("Title1") || !mainMenuPromoJson["newInfo"]["Title1"].IsString() || - !mainMenuPromoJson["newInfo"].HasMember("Title2") || !mainMenuPromoJson["newInfo"]["Title2"].IsString() || - !mainMenuPromoJson["newInfo"].HasMember("Title3") || !mainMenuPromoJson["newInfo"]["Title3"].IsString() || - - !mainMenuPromoJson.HasMember("largeButton") || !mainMenuPromoJson["largeButton"].IsObject() || - !mainMenuPromoJson["largeButton"].HasMember("Title") || !mainMenuPromoJson["largeButton"]["Title"].IsString() || - !mainMenuPromoJson["largeButton"].HasMember("Text") || !mainMenuPromoJson["largeButton"]["Text"].IsString() || - !mainMenuPromoJson["largeButton"].HasMember("Url") || !mainMenuPromoJson["largeButton"]["Url"].IsString() || - !mainMenuPromoJson["largeButton"].HasMember("ImageIndex") || - !mainMenuPromoJson["largeButton"]["ImageIndex"].IsNumber() || - - !mainMenuPromoJson.HasMember("smallButton1") || !mainMenuPromoJson["smallButton1"].IsObject() || - !mainMenuPromoJson["smallButton1"].HasMember("Title") || !mainMenuPromoJson["smallButton1"]["Title"].IsString() || - !mainMenuPromoJson["smallButton1"].HasMember("Url") || !mainMenuPromoJson["smallButton1"]["Url"].IsString() || - !mainMenuPromoJson["smallButton1"].HasMember("ImageIndex") || - !mainMenuPromoJson["smallButton1"]["ImageIndex"].IsNumber() || - - !mainMenuPromoJson.HasMember("smallButton2") || !mainMenuPromoJson["smallButton2"].IsObject() || - !mainMenuPromoJson["smallButton2"].HasMember("Title") || !mainMenuPromoJson["smallButton2"]["Title"].IsString() || - !mainMenuPromoJson["smallButton2"].HasMember("Url") || !mainMenuPromoJson["smallButton2"]["Url"].IsString() || - !mainMenuPromoJson["smallButton2"].HasMember("ImageIndex") || - !mainMenuPromoJson["smallButton2"]["ImageIndex"].IsNumber()) + catch (nlohmann::json::out_of_range& e) { - spdlog::error("Failed reading masterserver main menu promos response: malformed json object"); - return; + spdlog::error("Failed reading masterserver main menu promos response: encountered data error \"{}\"", e.what()); } - - m_sMainMenuPromoData.newInfoTitle1 = mainMenuPromoJson["newInfo"]["Title1"].GetString(); - m_sMainMenuPromoData.newInfoTitle2 = mainMenuPromoJson["newInfo"]["Title2"].GetString(); - m_sMainMenuPromoData.newInfoTitle3 = mainMenuPromoJson["newInfo"]["Title3"].GetString(); - - m_sMainMenuPromoData.largeButtonTitle = mainMenuPromoJson["largeButton"]["Title"].GetString(); - m_sMainMenuPromoData.largeButtonText = mainMenuPromoJson["largeButton"]["Text"].GetString(); - m_sMainMenuPromoData.largeButtonUrl = mainMenuPromoJson["largeButton"]["Url"].GetString(); - m_sMainMenuPromoData.largeButtonImageIndex = mainMenuPromoJson["largeButton"]["ImageIndex"].GetInt(); - - m_sMainMenuPromoData.smallButton1Title = mainMenuPromoJson["smallButton1"]["Title"].GetString(); - m_sMainMenuPromoData.smallButton1Url = mainMenuPromoJson["smallButton1"]["Url"].GetString(); - m_sMainMenuPromoData.smallButton1ImageIndex = mainMenuPromoJson["smallButton1"]["ImageIndex"].GetInt(); - - m_sMainMenuPromoData.smallButton2Title = mainMenuPromoJson["smallButton2"]["Title"].GetString(); - m_sMainMenuPromoData.smallButton2Url = mainMenuPromoJson["smallButton2"]["Url"].GetString(); - m_sMainMenuPromoData.smallButton2ImageIndex = mainMenuPromoJson["smallButton2"]["ImageIndex"].GetInt(); - - m_bHasMainMenuPromoData = true; } else { - spdlog::error("Failed requesting main menu promos: error {}", curl_easy_strerror(result)); + auto const err = res.error(); + spdlog::error("Failed reading masterserver main menu promos response:: {}", httplib::to_string(err)); m_bSuccessfullyConnected = false; } }); - requestThread.detach(); + request_thread.detach(); } -void MasterServerManager::AuthenticateWithOwnServer(const char* uid, const char* playerToken) +void MasterServerManager::AuthenticateWithOwnServer(const char* uid, const std::string& playerToken) { - // dont wait, just stop if we're trying to do 2 auth requests at once + + // don't wait, just stop if we're trying to do 2 auth requests at once if (m_bAuthenticatingWithGameServer || g_pVanillaCompatibility->GetVanillaCompatibility()) return; + m_sAuthFailureReason = "No error message provided"; + m_sAuthFailureMessage = "No error message provided"; + m_bAuthenticatingWithGameServer = true; m_bScriptAuthenticatingWithGameServer = true; m_bSuccessfullyAuthenticatedWithGameServer = false; - m_sAuthFailureReason = "Authentication Failed"; - std::string uidStr(uid); std::string tokenStr(playerToken); - std::thread requestThread( [this, uidStr, tokenStr]() { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format("{}/client/auth_with_self?id={}&playerToken={}", Cvar_ns_masterserver_hostname->GetString(), uidStr, tokenStr) - .c_str()); - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - CURLcode result = curl_easy_perform(curl); - ScopeGuard cleanup( - [&] - { - m_bAuthenticatingWithGameServer = false; - m_bScriptAuthenticatingWithGameServer = false; - - if (m_bNewgameAfterSelfAuth) - { - // pretty sure this is threadsafe? - Cbuf_AddText(Cbuf_GetCurrentPlayer(), "ns_end_reauth_and_leave_to_lobby", cmd_source_t::kCommandSrcCode); - m_bNewgameAfterSelfAuth = false; - } - - curl_easy_cleanup(curl); - }); - - if (result == CURLcode::CURLE_OK) + const std::string querystring = + fmt::format("/client/auth_with_self?id={}&playerToken={}", encode_query_param(uidStr), encode_query_param(tokenStr)); + httplib::Client cli = SetupHttpClient(); + auto res = cli.Post(querystring); + if (res && res->status == 200) { m_bSuccessfullyConnected = true; - - rapidjson_document authInfoJson; - authInfoJson.Parse(readBuffer.c_str()); - - if (authInfoJson.HasParseError()) + try { - spdlog::error( - "Failed reading masterserver authentication response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(authInfoJson.GetParseError())); - return; - } + // spdlog::info("{}", res->status); + nlohmann::json authInfoJson = nlohmann::json::parse(res->body); + RemoteAuthData newAuthData; - if (!authInfoJson.IsObject()) - { - spdlog::error("Failed reading masterserver authentication response: root object is not an object"); - return; - } + //newAuthData.uid = authInfoJson.at("id"); + authInfoJson.at("id").get_to(newAuthData.uid); - if (authInfoJson.HasMember("error")) - { - spdlog::error("Failed reading masterserver response: got fastify error response"); - spdlog::error(readBuffer); + const std::string original = authInfoJson.at("persistentData"); + newAuthData.pdata = base64_decode(original); - if (authInfoJson["error"].HasMember("msg")) - m_sAuthFailureReason = authInfoJson["error"]["msg"].GetString(); - else if (authInfoJson["error"].HasMember("enum")) - m_sAuthFailureReason = authInfoJson["error"]["enum"].GetString(); - else - m_sAuthFailureReason = "No error message provided"; + std::lock_guard guard(g_pServerAuthentication->m_AuthDataMutex); + g_pServerAuthentication->m_RemoteAuthenticationData.clear(); + g_pServerAuthentication->m_RemoteAuthenticationData.insert(std::make_pair(authInfoJson.at("authToken"), newAuthData)); - return; + m_bSuccessfullyAuthenticatedWithGameServer = true; + m_bAuthenticatingWithGameServer = false; + m_bScriptAuthenticatingWithGameServer = false; } - - if (!authInfoJson["success"].IsTrue()) + catch (nlohmann::json::parse_error& e) { - spdlog::error("Authentication with masterserver failed: \"success\" is not true"); - return; + spdlog::error("Failed authenticating with local server: encountered parse error \"{}\"", e.what()); + m_bSuccessfullyAuthenticatedWithGameServer = false; + m_bScriptAuthenticatingWithGameServer = false; + m_bAuthenticatingWithGameServer = false; } - - if (!authInfoJson.HasMember("success") || !authInfoJson.HasMember("id") || !authInfoJson["id"].IsString() || - !authInfoJson.HasMember("authToken") || !authInfoJson["authToken"].IsString() || - !authInfoJson.HasMember("persistentData") || !authInfoJson["persistentData"].IsArray()) + catch (nlohmann::json::out_of_range& e) { - spdlog::error("Failed reading masterserver authentication response: malformed json object"); - return; - } - - RemoteAuthData newAuthData {}; - strncpy_s(newAuthData.uid, sizeof(newAuthData.uid), authInfoJson["id"].GetString(), sizeof(newAuthData.uid) - 1); - - newAuthData.pdataSize = authInfoJson["persistentData"].GetArray().Size(); - newAuthData.pdata = new char[newAuthData.pdataSize]; - // memcpy(newAuthData.pdata, authInfoJson["persistentData"].GetString(), newAuthData.pdataSize); - - int i = 0; - // note: persistentData is a uint8array because i had problems getting strings to behave, it sucks but it's just how it be - // unfortunately potentially refactor later - for (auto& byte : authInfoJson["persistentData"].GetArray()) - { - if (!byte.IsUint() || byte.GetUint() > 255) - { - spdlog::error("Failed reading masterserver authentication response: malformed json object"); - return; - } - - newAuthData.pdata[i++] = static_cast(byte.GetUint()); + spdlog::error("Failed authenticating with local server: encountered data error \"{}\"", e.what()); + m_bSuccessfullyAuthenticatedWithGameServer = false; + m_bScriptAuthenticatingWithGameServer = false; + m_bAuthenticatingWithGameServer = false; } - - std::lock_guard guard(g_pServerAuthentication->m_AuthDataMutex); - g_pServerAuthentication->m_RemoteAuthenticationData.clear(); - g_pServerAuthentication->m_RemoteAuthenticationData.insert( - std::make_pair(authInfoJson["authToken"].GetString(), newAuthData)); - - m_bSuccessfullyAuthenticatedWithGameServer = true; } else { - spdlog::error("Failed authenticating with own server: error {}", curl_easy_strerror(result)); + auto err = res.error(); + spdlog::error("Failed authenticating with local server: encountered connection error {}", httplib::to_string(err)); m_bSuccessfullyConnected = false; m_bSuccessfullyAuthenticatedWithGameServer = false; m_bScriptAuthenticatingWithGameServer = false; } + + m_bAuthenticatingWithGameServer = false; + m_bScriptAuthenticatingWithGameServer = false; + + if (m_bNewgameAfterSelfAuth) + { + // pretty sure this is threadsafe? + Cbuf_AddText(Cbuf_GetCurrentPlayer(), "ns_end_reauth_and_leave_to_lobby", cmd_source_t::kCommandSrcCode); + m_bNewgameAfterSelfAuth = false; + } }); requestThread.detach(); } - -void MasterServerManager::AuthenticateWithServer(const char* uid, const char* playerToken, RemoteServerInfo server, const char* password) +void MasterServerManager::AuthenticateWithServer( + const char* uid, const char* playerToken, const char* serverId, const char* password) { // dont wait, just stop if we're trying to do 2 auth requests at once if (m_bAuthenticatingWithGameServer || g_pVanillaCompatibility->GetVanillaCompatibility()) return; + m_sAuthFailureReason = "No error message provided"; + m_sAuthFailureMessage = "No error message provided"; m_bAuthenticatingWithGameServer = true; m_bScriptAuthenticatingWithGameServer = true; m_bSuccessfullyAuthenticatedWithGameServer = false; - m_sAuthFailureReason = "Authentication Failed"; - std::string uidStr(uid); - std::string tokenStr(playerToken); - std::string serverIdStr(server.id); - std::string passwordStr(password); + std::string uid_str(uid); + const std::string& token_str(playerToken); + const std::string& server_id_str(serverId); + std::string password_str(password); std::thread requestThread( - [this, uidStr, tokenStr, serverIdStr, passwordStr, server]() + [this, uid_str, token_str, server_id_str, password_str]() { - // esnure that any persistence saving is done, so we know masterserver has newest - while (m_bSavingPersistentData) + // ensure that any persistence saving is done, so we know masterserver has newest + while (m_sPlayerPersistenceStates.contains(uid_str)) Sleep(100); - spdlog::info("Attempting authentication with server of id \"{}\"", serverIdStr); - - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); + spdlog::info("Attempting authentication with server of id \"{}\"", server_id_str); - std::string readBuffer; - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + httplib::Client cli = SetupHttpClient(); + const std::string querystring = fmt::format( + "/client/auth_with_server?id={}&playerToken={}&server={}&password={}", + encode_query_param(uid_str), + encode_query_param(token_str), + encode_query_param(server_id_str), + encode_query_param(password_str)); + auto res = cli.Post(querystring); - { - char* escapedPassword = curl_easy_escape(curl, passwordStr.c_str(), (int)passwordStr.length()); - - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/client/auth_with_server?id={}&playerToken={}&server={}&password={}", - Cvar_ns_masterserver_hostname->GetString(), - uidStr, - tokenStr, - serverIdStr, - escapedPassword) - .c_str()); - - curl_free(escapedPassword); - } - - CURLcode result = curl_easy_perform(curl); - ScopeGuard cleanup( - [&] - { - m_bAuthenticatingWithGameServer = false; - m_bScriptAuthenticatingWithGameServer = false; - curl_easy_cleanup(curl); - }); - - if (result == CURLcode::CURLE_OK) + if (res && res->status == 200) { m_bSuccessfullyConnected = true; - - rapidjson_document connectionInfoJson; - connectionInfoJson.Parse(readBuffer.c_str()); - - if (connectionInfoJson.HasParseError()) - { - spdlog::error( - "Failed reading masterserver authentication response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(connectionInfoJson.GetParseError())); - return; - } - - if (!connectionInfoJson.IsObject()) - { - spdlog::error("Failed reading masterserver authentication response: root object is not an object"); - return; - } - - if (connectionInfoJson.HasMember("error")) + try { - spdlog::error("Failed reading masterserver response: got fastify error response"); - spdlog::error(readBuffer); - - if (connectionInfoJson["error"].HasMember("msg")) - m_sAuthFailureReason = connectionInfoJson["error"]["msg"].GetString(); - else if (connectionInfoJson["error"].HasMember("enum")) - m_sAuthFailureReason = connectionInfoJson["error"]["enum"].GetString(); + nlohmann::json connection_info_json = nlohmann::json::parse(res->body); + if (connection_info_json.at("success") == true) + { + // spdlog::info("[auth_with_server] body: {}", res->body); + m_pendingConnectionInfo.ip.S_un.S_addr = inet_addr(std::string(connection_info_json.at("ip")).c_str()); + m_pendingConnectionInfo.port = static_cast(connection_info_json.at("port")); + m_pendingConnectionInfo.authToken = connection_info_json.at("authToken"); + + m_bHasPendingConnectionInfo = true; + m_bSuccessfullyAuthenticatedWithGameServer = true; + m_bScriptAuthenticatingWithGameServer = false; + m_bAuthenticatingWithGameServer = false; + } else - m_sAuthFailureReason = "No error message provided"; - - return; + { + m_sAuthFailureReason = connection_info_json.at("error").at("enum"); + m_sAuthFailureMessage = connection_info_json.at("error").at("msg"); + m_bSuccessfullyAuthenticatedWithGameServer = false; + m_bScriptAuthenticatingWithGameServer = false; + m_bAuthenticatingWithGameServer = false; + } } - - if (!connectionInfoJson["success"].IsTrue()) + catch (nlohmann::json::parse_error& e) { - spdlog::error("Authentication with masterserver failed: \"success\" is not true"); - return; + spdlog::error("Failed authenticating with server: encountered parse error \"{}\"", e.what()); } - - if (!connectionInfoJson.HasMember("success") || !connectionInfoJson.HasMember("ip") || - !connectionInfoJson["ip"].IsString() || !connectionInfoJson.HasMember("port") || - !connectionInfoJson["port"].IsNumber() || !connectionInfoJson.HasMember("authToken") || - !connectionInfoJson["authToken"].IsString()) + catch (nlohmann::json::out_of_range& e) { - spdlog::error("Failed reading masterserver authentication response: malformed json object"); - return; + spdlog::error("Failed authenticating with server: encountered data error \"{}\"", e.what()); } - - m_pendingConnectionInfo.ip.S_un.S_addr = inet_addr(connectionInfoJson["ip"].GetString()); - m_pendingConnectionInfo.port = (unsigned short)connectionInfoJson["port"].GetUint(); - - strncpy_s( - m_pendingConnectionInfo.authToken, - sizeof(m_pendingConnectionInfo.authToken), - connectionInfoJson["authToken"].GetString(), - sizeof(m_pendingConnectionInfo.authToken) - 1); - - m_bHasPendingConnectionInfo = true; - m_bSuccessfullyAuthenticatedWithGameServer = true; - - m_currentServer = server; - m_sCurrentServerPassword = passwordStr; } else { - spdlog::error("Failed authenticating with server: error {}", curl_easy_strerror(result)); + spdlog::error( + "Failed authenticating with server: {}", + res.error() == httplib::Error::Success ? fmt::format("{} {}", res->status, res->body) + : std::to_string(static_cast(res.error()))); m_bSuccessfullyConnected = false; m_bSuccessfullyAuthenticatedWithGameServer = false; m_bScriptAuthenticatingWithGameServer = false; + m_bAuthenticatingWithGameServer = false; } }); requestThread.detach(); } -void MasterServerManager::WritePlayerPersistentData(const char* playerId, const char* pdata, size_t pdataSize) +void MasterServerManager::WritePlayerPersistentData(const char* player_id, const char* pdata, size_t pdata_size) { + std::string strPlayerId(player_id); // still call this if we don't have a server id, since lobbies that aren't port forwarded need to be able to call it - m_bSavingPersistentData = true; - if (!pdataSize) - { - spdlog::warn("attempted to write pdata of size 0!"); - return; - } - - std::string strPlayerId(playerId); - std::string strPdata(pdata, pdataSize); - - std::thread requestThread( - [this, strPlayerId, strPdata, pdataSize] - { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/accounts/write_persistence?id={}&serverId={}", - Cvar_ns_masterserver_hostname->GetString(), - strPlayerId, - m_sOwnServerId) - .c_str()); - curl_easy_setopt(curl, CURLOPT_POST, 1L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - curl_mime* mime = curl_mime_init(curl); - curl_mimepart* part = curl_mime_addpart(mime); - - curl_mime_data(part, strPdata.c_str(), pdataSize); - curl_mime_name(part, "pdata"); - curl_mime_filename(part, "file.pdata"); - curl_mime_type(part, "application/octet-stream"); - - curl_easy_setopt(curl, CURLOPT_MIMEPOST, mime); - - CURLcode result = curl_easy_perform(curl); - - if (result == CURLcode::CURLE_OK) - m_bSuccessfullyConnected = true; - else - m_bSuccessfullyConnected = false; - - curl_easy_cleanup(curl); - - m_bSavingPersistentData = false; - }); - - requestThread.detach(); -} - -void MasterServerManager::ProcessConnectionlessPacketSigreq1(std::string data) -{ - rapidjson_document obj; - obj.Parse(data); - - if (obj.HasParseError()) + if (m_sPlayerPersistenceStates.contains(strPlayerId)) { - // note: it's okay to print the data as-is since we've already checked that it actually came from Atlas - spdlog::error("invalid Atlas connectionless packet request ({}): {}", data, GetParseError_En(obj.GetParseError())); + spdlog::warn("player {} attempted to write pdata while previous request still exists!", strPlayerId); + // player is already requesting for leave, ignore the request. return; } - - if (!obj.HasMember("type") || !obj["type"].IsString()) + if (!pdata_size) { - spdlog::error("invalid Atlas connectionless packet request ({}): missing type", data); + spdlog::warn("player {} attempted to write pdata of size 0!", strPlayerId); return; } - std::string type = obj["type"].GetString(); - - if (type == "connect") - { - if (!obj.HasMember("token") || !obj["token"].IsString()) - { - spdlog::error("failed to handle Atlas connect request: missing or invalid connection token field"); - return; - } - std::string token = obj["token"].GetString(); - - if (!m_handledServerConnections.contains(token)) - m_handledServerConnections.insert(token); - else - return; // already handled - - spdlog::info("handling Atlas connect request {}", data); - - if (!obj.HasMember("uid") || !obj["uid"].IsUint64()) - { - spdlog::error("failed to handle Atlas connect request {}: missing or invalid uid field", token); - return; - } - uint64_t uid = obj["uid"].GetUint64(); - - std::string username; - if (obj.HasMember("username") && obj["username"].IsString()) - username = obj["username"].GetString(); + m_PlayerPersistenceMutex.lock(); + m_sPlayerPersistenceStates.insert(strPlayerId); + m_PlayerPersistenceMutex.unlock(); - std::string reject; - if (!g_pBanSystem->IsUIDAllowed(uid)) - reject = "Banned from this server."; + std::vector str_pdata(pdata_size); + memcpy(str_pdata.data(), pdata, pdata_size); - std::string pdata; - if (reject == "") + std::thread request_thread( + [this, strPlayerId, str_pdata, pdata_size] { - spdlog::info("getting pdata for connection {} (uid={} username={})", token, uid, username); - - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format("{}/server/connect?serverId={}&token={}", Cvar_ns_masterserver_hostname->GetString(), m_sOwnServerId, token) - .c_str()); - - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &pdata); - - CURLcode result = curl_easy_perform(curl); - if (result != CURLcode::CURLE_OK) - { - spdlog::error("failed to make Atlas connect pdata request {}: {}", token, curl_easy_strerror(result)); - curl_easy_cleanup(curl); - return; - } - - long respStatus = -1; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &respStatus); - - curl_easy_cleanup(curl); - - if (respStatus != 200) + spdlog::info("[Pdata] Writing persistence for user: {}", strPlayerId); + httplib::Client cli = SetupHttpClient(); + const std::string querystring = fmt::format( + "/accounts/write_persistence?id={}&serverId={}", encode_query_param(strPlayerId), encode_query_param(m_sOwnServerId)); + const std::string encoded = base64_encode(str_pdata.data(), pdata_size); + auto res = cli.Post(querystring, encoded, "text/plain"); + if (res != nullptr) { - rapidjson_document obj; - obj.Parse(pdata.c_str()); - - if (!obj.HasParseError() && obj.HasMember("error") && obj["error"].IsObject()) - spdlog::error( - "failed to make Atlas connect pdata request {}: response status {}, error: {} ({})", - token, - respStatus, - ((obj["error"].HasMember("enum") && obj["error"]["enum"].IsString()) ? obj["error"]["enum"].GetString() : ""), - ((obj["error"].HasMember("msg") && obj["error"]["msg"].IsString()) ? obj["error"]["msg"].GetString() : "")); + if (res->status == 200) + { + spdlog::info("[Pdata] Successfully wrote pdata for user: {}", strPlayerId); + m_bSuccessfullyConnected = true; + } else - spdlog::error("failed to make Atlas connect pdata request {}: response status {}", token, respStatus); - return; - } - - if (!pdata.length()) - { - spdlog::error("failed to make Atlas connect pdata request {}: pdata response is empty", token); - return; - } - - if (pdata.length() > PERSISTENCE_MAX_SIZE) - { - spdlog::error( - "failed to make Atlas connect pdata request {}: pdata is too large (max={} len={})", - token, - PERSISTENCE_MAX_SIZE, - pdata.length()); - return; - } - } - - if (reject == "") - spdlog::info("accepting connection {} (uid={} username={}) with {} bytes of pdata", token, uid, username, pdata.length()); - else - spdlog::info("rejecting connection {} (uid={} username={}) with reason \"{}\"", token, uid, username, reject); - - if (reject == "") - g_pServerAuthentication->AddRemotePlayer(token, uid, username, pdata); - - { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); + { + auto err = res->body; + spdlog::error("[Pdata] Write persistence failed for user: {}, error: {}", strPlayerId, err); - char* rejectEnc = curl_easy_escape(curl, reject.c_str(), (int)reject.length()); - if (!rejectEnc) - { - spdlog::error("failed to handle Atlas connect request {}: failed to escape reject", token); - return; + m_bSuccessfullyConnected = true; + } } - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/server/connect?serverId={}&token={}&reject={}", - Cvar_ns_masterserver_hostname->GetString(), - m_sOwnServerId, - token, - rejectEnc) - .c_str()); - curl_free(rejectEnc); - - // note: we don't actually have any POST data, so we can't use CURLOPT_POST or the behavior is undefined (e.g., hangs in wine) - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); - - std::string buf; - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buf); - - CURLcode result = curl_easy_perform(curl); - if (result != CURLcode::CURLE_OK) + else { - spdlog::error("failed to respond to Atlas connect request {}: {}", token, curl_easy_strerror(result)); - curl_easy_cleanup(curl); - return; - } + auto err = res.error(); + spdlog::error("[Pdata] Write persistence failed for user: {}, error: {}", strPlayerId, httplib::to_string(err)); - long respStatus = -1; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &respStatus); - - curl_easy_cleanup(curl); - - if (respStatus != 200) - { - rapidjson_document obj; - obj.Parse(buf.c_str()); - - if (!obj.HasParseError() && obj.HasMember("error") && obj["error"].IsObject()) - spdlog::error( - "failed to respond to Atlas connect request {}: response status {}, error: {} ({})", - token, - respStatus, - ((obj["error"].HasMember("enum") && obj["error"]["enum"].IsString()) ? obj["error"]["enum"].GetString() : ""), - ((obj["error"].HasMember("msg") && obj["error"]["msg"].IsString()) ? obj["error"]["msg"].GetString() : "")); - else - spdlog::error("failed to respond to Atlas connect request {}: response status {}", token, respStatus); - return; + m_bSuccessfullyConnected = false; } - } - return; - } + m_PlayerPersistenceMutex.lock(); + m_sPlayerPersistenceStates.erase(strPlayerId); + m_PlayerPersistenceMutex.unlock(); + }); - spdlog::error("invalid Atlas connectionless packet request: unknown type {}", type); + request_thread.detach(); } void ConCommand_ns_fetchservers(const CCommand& args) { - NOTE_UNUSED(args); g_pMasterServerManager->RequestServerList(); } MasterServerManager::MasterServerManager() - : m_pendingConnectionInfo {} - , m_sOwnServerId {""} + : m_sOwnServerId {""} , m_sOwnClientAuthToken {""} + , m_pendingConnectionInfo {} { } ON_DLL_LOAD_RELIESON("engine.dll", MasterServer, (ConCommand, ServerPresence), (CModule module)) { g_pMasterServerManager = new MasterServerManager; - + Cvar_ns_server_reg_token = new ConVar("ns_server_reg_token", "0", FCVAR_GAMEDLL, "Server account string used for registration"); Cvar_ns_masterserver_hostname = new ConVar("ns_masterserver_hostname", "127.0.0.1", FCVAR_NONE, ""); + Cvar_ns_matchmaker_hostname = new ConVar("ns_matchmaker_hostname", "127.0.0.1", FCVAR_NONE, ""); Cvar_ns_curl_log_enable = new ConVar("ns_curl_log_enable", "0", FCVAR_NONE, "Whether curl should log to the console"); RegisterConCommand("ns_fetchservers", ConCommand_ns_fetchservers, "Fetch all servers from the masterserver", FCVAR_CLIENTDLL); - MasterServerPresenceReporter* presenceReporter = new MasterServerPresenceReporter; - g_pServerPresence->AddPresenceReporter(presenceReporter); + auto* presence_reporter = new MasterServerPresenceReporter; + g_pServerPresence->AddPresenceReporter(presence_reporter); } void MasterServerPresenceReporter::CreatePresence(const ServerPresence* pServerPresence) { - NOTE_UNUSED(pServerPresence); m_nNumRegistrationAttempts = 0; } void MasterServerPresenceReporter::ReportPresence(const ServerPresence* pServerPresence) { - // make a copy of presence for multithreading purposes - ServerPresence threadedPresence(pServerPresence); + // make a copy of presence for multi threading purposes + ServerPresence threaded_presence(pServerPresence); if (!*g_pMasterServerManager->m_sOwnServerId) { @@ -1058,7 +968,6 @@ void MasterServerPresenceReporter::ReportPresence(const ServerPresence* pServerP void MasterServerPresenceReporter::DestroyPresence(const ServerPresence* pServerPresence) { - NOTE_UNUSED(pServerPresence); // Don't call this if we don't have a server id. if (!*g_pMasterServerManager->m_sOwnServerId) { @@ -1068,39 +977,21 @@ void MasterServerPresenceReporter::DestroyPresence(const ServerPresence* pServer // Not bothering with better thread safety in this case since DestroyPresence() is called when the game is shutting down. *g_pMasterServerManager->m_sOwnServerId = 0; - std::thread requestThread( - [this] - { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "DELETE"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/server/remove_server?id={}", Cvar_ns_masterserver_hostname->GetString(), g_pMasterServerManager->m_sOwnServerId) - .c_str()); - - CURLcode result = curl_easy_perform(curl); - curl_easy_cleanup(curl); - }); - - requestThread.detach(); + httplib::Client cli = SetupHttpClient(); + const std::string querystring = fmt::format( + "/server/remove_server?id={}?serverAuthToken={}", + encode_query_param(g_pMasterServerManager->m_sOwnServerId), + encode_query_param(g_pMasterServerManager->m_sOwnServerAuthToken)); + cli.Delete(querystring); } void MasterServerPresenceReporter::RunFrame(double flCurrentTime, const ServerPresence* pServerPresence) { - NOTE_UNUSED(flCurrentTime); - NOTE_UNUSED(pServerPresence); // Check if we're already running an InternalAddServer() call in the background. // If so, grab the result if it's ready. if (addServerFuture.valid()) { - std::future_status status = addServerFuture.wait_for(0ms); + const std::future_status status = addServerFuture.wait_for(0ms); if (status != std::future_status::ready) { // Still running, no need to do anything. @@ -1108,23 +999,23 @@ void MasterServerPresenceReporter::RunFrame(double flCurrentTime, const ServerPr } // Check the result. - auto resultData = addServerFuture.get(); + const auto result_data = addServerFuture.get(); - g_pMasterServerManager->m_bSuccessfullyConnected = resultData.result != MasterServerReportPresenceResult::FailedNoConnect; + g_pMasterServerManager->m_bSuccessfullyConnected = result_data.result != MasterServerReportPresenceResult::FailedNoConnect; - switch (resultData.result) + switch (result_data.result) { case MasterServerReportPresenceResult::Success: // Copy over the server id and auth token granted by the MS. strncpy_s( g_pMasterServerManager->m_sOwnServerId, sizeof(g_pMasterServerManager->m_sOwnServerId), - resultData.id.value().c_str(), + result_data.id.value().c_str(), sizeof(g_pMasterServerManager->m_sOwnServerId) - 1); strncpy_s( g_pMasterServerManager->m_sOwnServerAuthToken, sizeof(g_pMasterServerManager->m_sOwnServerAuthToken), - resultData.serverAuthToken.value().c_str(), + result_data.serverAuthToken.value().c_str(), sizeof(g_pMasterServerManager->m_sOwnServerAuthToken) - 1); break; case MasterServerReportPresenceResult::FailedNoRetry: @@ -1144,39 +1035,38 @@ void MasterServerPresenceReporter::RunFrame(double flCurrentTime, const ServerPr if (m_nNumRegistrationAttempts >= MAX_REGISTRATION_ATTEMPTS) { - spdlog::log( - IsDedicatedServer() ? spdlog::level::level_enum::err : spdlog::level::level_enum::warn, - "Reached max ms server registration attempts."); + spdlog::error("Reached max ms server registration attempts."); } } else if (updateServerFuture.valid()) { // Check if the InternalUpdateServer() call completed. - std::future_status status = updateServerFuture.wait_for(0ms); + const std::future_status status = updateServerFuture.wait_for(0ms); + if (status != std::future_status::ready) { // Still running, no need to do anything. return; } - auto resultData = updateServerFuture.get(); - if (resultData.result == MasterServerReportPresenceResult::Success) + const auto result_data = updateServerFuture.get(); + if (result_data.result == MasterServerReportPresenceResult::Success) { - if (resultData.id) + if (result_data.id) { strncpy_s( g_pMasterServerManager->m_sOwnServerId, sizeof(g_pMasterServerManager->m_sOwnServerId), - resultData.id.value().c_str(), + result_data.id.value().c_str(), sizeof(g_pMasterServerManager->m_sOwnServerId) - 1); } - if (resultData.serverAuthToken) + if (result_data.serverAuthToken) { strncpy_s( g_pMasterServerManager->m_sOwnServerAuthToken, sizeof(g_pMasterServerManager->m_sOwnServerAuthToken), - resultData.serverAuthToken.value().c_str(), + result_data.serverAuthToken.value().c_str(), sizeof(g_pMasterServerManager->m_sOwnServerAuthToken) - 1); } } @@ -1185,299 +1075,161 @@ void MasterServerPresenceReporter::RunFrame(double flCurrentTime, const ServerPr void MasterServerPresenceReporter::InternalAddServer(const ServerPresence* pServerPresence) { - const ServerPresence threadedPresence(pServerPresence); + const ServerPresence threaded_presence(pServerPresence); // Never call this with an ongoing InternalAddServer() call. assert(!addServerFuture.valid()); g_pMasterServerManager->m_sOwnServerId[0] = 0; g_pMasterServerManager->m_sOwnServerAuthToken[0] = 0; - std::string modInfo = g_pMasterServerManager->m_sOwnModInfoJson; + std::string mod_info = g_pMasterServerManager->m_sOwnModInfoJson; std::string hostname = Cvar_ns_masterserver_hostname->GetString(); + std::string server_account = Cvar_ns_server_reg_token->GetString(); spdlog::info("Attempting to register the local server to the master server."); addServerFuture = std::async( std::launch::async, - [threadedPresence, modInfo, hostname, pServerPresence] + [threaded_presence, mod_info, hostname, server_account] { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - - std::string readBuffer; - curl_easy_setopt(curl, CURLOPT_POST, 1L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - curl_mime* mime = curl_mime_init(curl); - curl_mimepart* part = curl_mime_addpart(mime); + httplib::Client cli = SetupHttpClient(); + std::string querystring = fmt::format( + "/server/" + "add_server?port={}&authPort={}&name={}&description={}&map={}&playlist={}&maxPlayers={}&password={}&serverRegToken=" + "{}&isMmServer={}", + threaded_presence.m_iPort, + threaded_presence.m_iAuthPort, + encode_query_param(threaded_presence.m_sServerName), + encode_query_param(threaded_presence.m_sServerDesc), + encode_query_param(threaded_presence.m_MapName), + encode_query_param(threaded_presence.m_PlaylistName), + threaded_presence.m_iMaxPlayers, + encode_query_param(threaded_presence.m_Password), + encode_query_param(server_account), + CommandLine()->CheckParm("-matchmaking") ? "true" : "false"); // Lambda to quickly cleanup resources and return a value. - auto ReturnCleanup = - [curl, mime](MasterServerReportPresenceResult result, const char* id = "", const char* serverAuthToken = "") + auto return_cleanup = + [](const MasterServerReportPresenceResult result, const std::string& id = "", const std::string& server_auth_token = "") { - curl_easy_cleanup(curl); - curl_mime_free(mime); - MasterServerPresenceReporter::ReportPresenceResultData data; data.result = result; data.id = id; - data.serverAuthToken = serverAuthToken; - + data.serverAuthToken = server_auth_token; return data; }; + // spdlog::info("{}", mod_info); + auto res = cli.Post(querystring, mod_info, "application/json"); - // don't log errors if we wouldn't actually show up in the server list anyway (stop tickets) - // except for dedis, for which this error logging is actually pretty important - bool shouldLogError = IsDedicatedServer() || (!strstr(pServerPresence->m_MapName, "mp_lobby") && - strstr(pServerPresence->m_PlaylistName, "private_match")); - - curl_mime_data(part, modInfo.c_str(), modInfo.size()); - curl_mime_name(part, "modinfo"); - curl_mime_filename(part, "modinfo.json"); - curl_mime_type(part, "application/json"); - - curl_easy_setopt(curl, CURLOPT_MIMEPOST, mime); - - // format every paramter because computers hate me + if (res && res->status == 200) { - char* nameEscaped = curl_easy_escape(curl, threadedPresence.m_sServerName.c_str(), 0); - char* descEscaped = curl_easy_escape(curl, threadedPresence.m_sServerDesc.c_str(), 0); - char* mapEscaped = curl_easy_escape(curl, threadedPresence.m_MapName, 0); - char* playlistEscaped = curl_easy_escape(curl, threadedPresence.m_PlaylistName, 0); - char* passwordEscaped = curl_easy_escape(curl, threadedPresence.m_Password, 0); - - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/server/" - "add_server?port={}&authPort=udp&name={}&description={}&map={}&playlist={}&maxPlayers={}&password={}", - hostname.c_str(), - threadedPresence.m_iPort, - nameEscaped, - descEscaped, - mapEscaped, - playlistEscaped, - threadedPresence.m_iMaxPlayers, - passwordEscaped) - .c_str()); - - curl_free(nameEscaped); - curl_free(descEscaped); - curl_free(mapEscaped); - curl_free(playlistEscaped); - curl_free(passwordEscaped); - } - - CURLcode result = curl_easy_perform(curl); - - if (result == CURLcode::CURLE_OK) - { - rapidjson_document serverAddedJson; - serverAddedJson.Parse(readBuffer.c_str()); - - // If we could not parse the JSON or it isn't an object, assume the MS is either wrong or we're completely out of date. - // No retry. - if (serverAddedJson.HasParseError()) - { - if (shouldLogError) - spdlog::error( - "Failed reading masterserver authentication response: encountered parse error \"{}\"", - rapidjson::GetParseError_En(serverAddedJson.GetParseError())); - return ReturnCleanup(MasterServerReportPresenceResult::FailedNoRetry); - } - - if (!serverAddedJson.IsObject()) + try { - if (shouldLogError) - spdlog::error("Failed reading masterserver authentication response: root object is not an object"); - return ReturnCleanup(MasterServerReportPresenceResult::FailedNoRetry); - } - - // Log request id for easier debugging when combining with logs on masterserver - if (serverAddedJson.HasMember("id")) - { - spdlog::info("Request id: {}", serverAddedJson["id"].GetString()); - } - else - { - spdlog::error("Couldn't find request id in response"); - } - - if (serverAddedJson.HasMember("error")) - { - if (shouldLogError) + nlohmann::json server_added_json = nlohmann::json::parse(res->body); + if (server_added_json["success"]) { - spdlog::error("Failed reading masterserver response: got fastify error response"); - spdlog::error(readBuffer); + spdlog::info("Successfully registered the local server to the master server."); + return return_cleanup( + MasterServerReportPresenceResult::Success, server_added_json.at("id"), server_added_json.at("serverAuthToken")); } - - // If this is DUPLICATE_SERVER, we'll retry adding the server every 20 seconds. - // The master server will only update its internal server list and clean up dead servers on certain events. - // And then again, only if a player requests the server list after the cooldown (1 second by default), or a server is - // added/updated/removed. In any case this needs to be fixed in the master server rewrite. - if (serverAddedJson["error"].HasMember("enum") && - strcmp(serverAddedJson["error"]["enum"].GetString(), "DUPLICATE_SERVER") == 0) + else { - if (shouldLogError) + if (!strcmp(std::string(server_added_json.at("error").at("enum")).c_str(), "DUPLICATE_SERVER")) + { spdlog::error("Cooling down while the master server cleans the dead server entry, if any."); - return ReturnCleanup(MasterServerReportPresenceResult::FailedDuplicateServer); + return return_cleanup(MasterServerReportPresenceResult::FailedDuplicateServer); + } + return return_cleanup(MasterServerReportPresenceResult::Failed); } - - // Retry until we reach max retries. - return ReturnCleanup(MasterServerReportPresenceResult::Failed); } - - if (!serverAddedJson["success"].IsTrue()) + catch (nlohmann::json::parse_error& e) { - if (shouldLogError) - spdlog::error("Adding server to masterserver failed: \"success\" is not true"); - return ReturnCleanup(MasterServerReportPresenceResult::FailedNoRetry); + spdlog::error("Failed registering server: encountered parse error \"{}\"", e.what()); + return return_cleanup(MasterServerReportPresenceResult::FailedNoRetry); } - - if (!serverAddedJson.HasMember("id") || !serverAddedJson["id"].IsString() || - !serverAddedJson.HasMember("serverAuthToken") || !serverAddedJson["serverAuthToken"].IsString()) + catch (nlohmann::json::out_of_range& e) { - if (shouldLogError) - spdlog::error("Failed reading masterserver response: malformed json object"); - return ReturnCleanup(MasterServerReportPresenceResult::FailedNoRetry); + spdlog::error("Failed registering server: encountered data error \"{}\"", e.what()); + return return_cleanup(MasterServerReportPresenceResult::FailedNoRetry); } - - spdlog::info("Successfully registered the local server to the master server."); - return ReturnCleanup( - MasterServerReportPresenceResult::Success, - serverAddedJson["id"].GetString(), - serverAddedJson["serverAuthToken"].GetString()); } else { - if (shouldLogError) - spdlog::error("Failed adding self to server list: error {}", curl_easy_strerror(result)); - return ReturnCleanup(MasterServerReportPresenceResult::FailedNoConnect); + spdlog::error("Failed adding self to server list: error {}", std::to_string(static_cast(res.error()))); + if (!res->body.empty()) + spdlog::error("res:{}", res->body); + return return_cleanup(MasterServerReportPresenceResult::FailedNoConnect); } }); } void MasterServerPresenceReporter::InternalUpdateServer(const ServerPresence* pServerPresence) { - const ServerPresence threadedPresence(pServerPresence); + const ServerPresence threaded_presence(pServerPresence); // Never call this with an ongoing InternalUpdateServer() call. assert(!updateServerFuture.valid()); - const std::string serverId = g_pMasterServerManager->m_sOwnServerId; + const std::string server_id = g_pMasterServerManager->m_sOwnServerId; const std::string hostname = Cvar_ns_masterserver_hostname->GetString(); - const std::string modinfo = g_pMasterServerManager->m_sOwnModInfoJson; + const std::string mod_info = g_pMasterServerManager->m_sOwnModInfoJson; + const std::string server_account = Cvar_ns_server_reg_token->GetString(); updateServerFuture = std::async( std::launch::async, - [threadedPresence, serverId, hostname, modinfo] + [threaded_presence, server_id, hostname, mod_info, server_account] { - CURL* curl = curl_easy_init(); - SetCommonHttpClientOptions(curl); - // Lambda to quickly cleanup resources and return a value. - auto ReturnCleanup = [curl](MasterServerReportPresenceResult result, const char* id = "", const char* serverAuthToken = "") + auto return_cleanup = + [](const MasterServerReportPresenceResult result, const std::string& id = "", const std::string& server_auth_token = "") { - curl_easy_cleanup(curl); - MasterServerPresenceReporter::ReportPresenceResultData data; data.result = result; - if (id != nullptr) + if (!id.empty()) { data.id = id; } - if (serverAuthToken != nullptr) + if (!server_auth_token.empty()) { - data.serverAuthToken = serverAuthToken; + data.serverAuthToken = server_auth_token; } return data; }; - std::string readBuffer; - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteToStringBufferCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - curl_easy_setopt(curl, CURLOPT_VERBOSE, 0L); - - // send all registration info so we have all necessary info to reregister our server if masterserver goes down, - // without a restart this isn't threadsafe :terror: - { - char* nameEscaped = curl_easy_escape(curl, threadedPresence.m_sServerName.c_str(), 0); - char* descEscaped = curl_easy_escape(curl, threadedPresence.m_sServerDesc.c_str(), 0); - char* mapEscaped = curl_easy_escape(curl, threadedPresence.m_MapName, 0); - char* playlistEscaped = curl_easy_escape(curl, threadedPresence.m_PlaylistName, 0); - char* passwordEscaped = curl_easy_escape(curl, threadedPresence.m_Password, 0); - - curl_easy_setopt( - curl, - CURLOPT_URL, - fmt::format( - "{}/server/" - "update_values?id={}&port={}&authPort=udp&name={}&description={}&map={}&playlist={}&playerCount={}&" - "maxPlayers={}&password={}", - hostname.c_str(), - serverId.c_str(), - threadedPresence.m_iPort, - nameEscaped, - descEscaped, - mapEscaped, - playlistEscaped, - threadedPresence.m_iPlayerCount, - threadedPresence.m_iMaxPlayers, - passwordEscaped) - .c_str()); - - curl_free(nameEscaped); - curl_free(descEscaped); - curl_free(mapEscaped); - curl_free(playlistEscaped); - curl_free(passwordEscaped); - } - - curl_mime* mime = curl_mime_init(curl); - curl_mimepart* part = curl_mime_addpart(mime); - - curl_mime_data(part, modinfo.c_str(), modinfo.size()); - curl_mime_name(part, "modinfo"); - curl_mime_filename(part, "modinfo.json"); - curl_mime_type(part, "application/json"); - - curl_easy_setopt(curl, CURLOPT_MIMEPOST, mime); - - CURLcode result = curl_easy_perform(curl); - - if (result == CURLcode::CURLE_OK) + httplib::Client cli = SetupHttpClient(); + std::string querystring = fmt::format( + "/server/" + "update_values?id={}&port={}&authPort={}&name={}&description={}&map={}&playlist={}&playerCount={}&" + "maxPlayers={}&password={}&gameState={}&serverAuthToken={}", + server_id.c_str(), + threaded_presence.m_iPort, + threaded_presence.m_iAuthPort, + encode_query_param(threaded_presence.m_sServerName), + encode_query_param(threaded_presence.m_sServerDesc), + encode_query_param(threaded_presence.m_MapName), + encode_query_param(threaded_presence.m_PlaylistName), + threaded_presence.m_iPlayerCount, + threaded_presence.m_iMaxPlayers, + encode_query_param(threaded_presence.m_Password), + encode_query_param(std::to_string(g_pSQGameState->eGameState)), + encode_query_param(g_pMasterServerManager->m_sOwnServerAuthToken)); + auto res = cli.Post(querystring, mod_info, "application/json"); + std::string updated_id; + std::string updated_auth_token; + if (res && res->status == 200) { - rapidjson_document serverAddedJson; - serverAddedJson.Parse(readBuffer.c_str()); - - const char* updatedId = nullptr; - const char* updatedAuthToken = nullptr; - - if (!serverAddedJson.HasParseError() && serverAddedJson.IsObject()) - { - if (serverAddedJson.HasMember("id") && serverAddedJson["id"].IsString()) - { - updatedId = serverAddedJson["id"].GetString(); - } - - if (serverAddedJson.HasMember("serverAuthToken") && serverAddedJson["serverAuthToken"].IsString()) - { - updatedAuthToken = serverAddedJson["serverAuthToken"].GetString(); - } - } - - return ReturnCleanup(MasterServerReportPresenceResult::Success, updatedId, updatedAuthToken); + return return_cleanup(MasterServerReportPresenceResult::Success); } else { - spdlog::warn("Heartbeat failed with error {}", curl_easy_strerror(result)); - return ReturnCleanup(MasterServerReportPresenceResult::Failed); + spdlog::error( + "Failed during heartbeat request: {}", + res.error() == httplib::Error::Success ? fmt::format("{} {}", res->status, res->body) + : std::to_string(static_cast(res.error()))); + return return_cleanup(MasterServerReportPresenceResult::Failed); } }); } diff --git a/primedev/masterserver/masterserver.h b/primedev/masterserver/masterserver.h index 570db619f..c1eece070 100644 --- a/primedev/masterserver/masterserver.h +++ b/primedev/masterserver/masterserver.h @@ -4,16 +4,16 @@ #include "server/serverpresence.h" #include #include +#include #include #include +#include "scripts/scriptmatchmakingevents.h" #include - extern ConVar* Cvar_ns_masterserver_hostname; extern ConVar* Cvar_ns_curl_log_enable; struct RemoteModInfo { -public: std::string Name; std::string Version; }; @@ -21,14 +21,14 @@ struct RemoteModInfo class RemoteServerInfo { public: - char id[33]; // 32 bytes + nullterminator + char id[33]; // server info char name[64]; std::string description; char map[32]; char playlist[16]; - char region[32]; + int gameState; std::vector requiredMods; int playerCount; @@ -37,14 +37,13 @@ class RemoteServerInfo // connection stuff bool requiresPassword; -public: RemoteServerInfo( const char* newId, const char* newName, const char* newDescription, const char* newMap, const char* newPlaylist, - const char* newRegion, + int newGameState, int newPlayerCount, int newMaxPlayers, bool newRequiresPassword); @@ -52,16 +51,13 @@ class RemoteServerInfo struct RemoteServerConnectionInfo { -public: - char authToken[32]; - + std::string authToken; in_addr ip; unsigned short port; }; struct MainMenuPromoData { -public: std::string newInfoTitle1; std::string newInfoTitle2; std::string newInfoTitle3; @@ -83,23 +79,36 @@ struct MainMenuPromoData class MasterServerManager { private: + bool m_RequestingClantag = false; + bool m_RequestingRemoteBanlistVersion = false; + bool m_RequestingRemoteBanlist = false; bool m_bRequestingServerList = false; bool m_bAuthenticatingWithGameServer = false; public: + std::unordered_set m_sPlayerPersistenceStates; + std::mutex m_PlayerPersistenceMutex; + char m_sOwnServerId[33]; char m_sOwnServerAuthToken[33]; - char m_sOwnClientAuthToken[33]; + std::string m_sOwnClientAuthToken; std::string m_sOwnModInfoJson; + std::string RemoteBanlistString; + std::string LocalBanlistVersion = "undefined"; + std::string RemoteBanlistVersion; + bool m_bOriginAuthWithMasterServerDone = false; bool m_bOriginAuthWithMasterServerInProgress = false; - bool m_bOriginAuthWithMasterServerSuccessful = false; + bool m_bOriginAuthWithMasterServerSuccess = false; + std::string m_sOriginAuthWithMasterServerErrorCode = ""; std::string m_sOriginAuthWithMasterServerErrorMessage = ""; + + bool m_bSavingPersistentData = false; bool m_bScriptRequestingServerList = false; @@ -109,6 +118,7 @@ class MasterServerManager bool m_bScriptAuthenticatingWithGameServer = false; bool m_bSuccessfullyAuthenticatedWithGameServer = false; std::string m_sAuthFailureReason {}; + std::string m_sAuthFailureMessage {}; bool m_bHasPendingConnectionInfo = false; RemoteServerConnectionInfo m_pendingConnectionInfo; @@ -118,26 +128,30 @@ class MasterServerManager bool m_bHasMainMenuPromoData = false; MainMenuPromoData m_sMainMenuPromoData; - std::optional m_currentServer; - std::string m_sCurrentServerPassword; - - std::unordered_set m_handledServerConnections; - public: MasterServerManager(); - void ClearServerList(); void RequestServerList(); void RequestMainMenuPromos(); void AuthenticateOriginWithMasterServer(const char* uid, const char* originToken); - void AuthenticateWithOwnServer(const char* uid, const char* playerToken); - void AuthenticateWithServer(const char* uid, const char* playerToken, RemoteServerInfo server, const char* password); - void WritePlayerPersistentData(const char* playerId, const char* pdata, size_t pdataSize); - void ProcessConnectionlessPacketSigreq1(std::string req); + void AuthenticateWithOwnServer(const char* uid, const std::string& playerToken); + void AuthenticateWithServer(const char* uid, const char* playerToken, const char* serverId, const char* password); + bool AuthenticateWithMatchmakingServer( + RemoteServerConnectionInfo& conn_info, + const char* uid, + const std::string& playerToken, + const std::string& serverId, + const char* password); + void WritePlayerPersistentData(const char* player_id, const char* pdata, size_t pdata_size); + bool SetLocalPlayerClanTag(std::string clantag); + bool StartMatchmaking(MatchmakeInfo* status); + bool CancelMatchmaking(); + bool UpdateMatchmakingStatus(MatchmakeInfo* status); }; extern MasterServerManager* g_pMasterServerManager; extern ConVar* Cvar_ns_masterserver_hostname; +extern ConVar* Cvar_ns_matchmaker_hostname; /** Result returned in the std::future of a MasterServerPresenceReporter::ReportPresence() call. */ enum class MasterServerReportPresenceResult @@ -189,9 +203,10 @@ class MasterServerPresenceReporter : public ServerPresenceReporter // The future used for InternalAddServer() calls. std::future addServerFuture; - + std::thread addServerThread; // The future used for InternalAddServer() calls. std::future updateServerFuture; + std::thread updateServerThread; int m_nNumRegistrationAttempts; diff --git a/primedev/ns_version.h b/primedev/ns_version.h index d30594fbd..d4d367a79 100644 --- a/primedev/ns_version.h +++ b/primedev/ns_version.h @@ -2,6 +2,6 @@ #ifndef NORTHSTAR_VERSION // Turning off clang-format here so it doesn't mess with style as it needs to be this way for regex-ing with CI // clang-format off -#define NORTHSTAR_VERSION 0,0,0,1 +#define NORTHSTAR_VERSION 1,18,0,0 // clang-format on #endif diff --git a/primedev/primelauncher/main.cpp b/primedev/primelauncher/main.cpp index 96c96c047..bf3f1c168 100644 --- a/primedev/primelauncher/main.cpp +++ b/primedev/primelauncher/main.cpp @@ -11,6 +11,7 @@ #include #include +#include namespace fs = std::filesystem; @@ -27,6 +28,72 @@ HMODULE hTier0Module; wchar_t exePath[4096]; wchar_t buffer[8192]; +static std::string ConvertWideToANSI(const std::wstring& wstr) +{ + int count = WideCharToMultiByte(CP_ACP, 0, wstr.c_str(), wstr.length(), NULL, 0, NULL, NULL); + std::string str(count, 0); + WideCharToMultiByte(CP_ACP, 0, wstr.c_str(), -1, &str[0], count, NULL, NULL); + return str; +} + +static std::wstring ConvertAnsiToWide(const std::string& str) +{ + int count = MultiByteToWideChar(CP_ACP, 0, str.c_str(), str.length(), NULL, 0); + std::wstring wstr(count, 0); + MultiByteToWideChar(CP_ACP, 0, str.c_str(), str.length(), &wstr[0], count); + return wstr; +} + +static std::string ConvertWideToUtf8(const std::wstring& wstr) +{ + int count = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), wstr.length(), NULL, 0, NULL, NULL); + std::string str(count, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, &str[0], count, NULL, NULL); + return str; +} + +static std::wstring ConvertUtf8ToWide(const std::string& str) +{ + int count = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), str.length(), NULL, 0); + std::wstring wstr(count, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), str.length(), &wstr[0], count); + return wstr; +} + +static std::string SanitizeEncodings(const char* buf) +{ + std::wstring ws = ConvertAnsiToWide(buf); + return ConvertWideToUtf8(ws); +} + +void RunUpdater() +{ + fs::path updater_path = std::filesystem::current_path() / L"NSCN_Updater.exe"; + // run updater when we don't have -updated present and updater exists + if (std::filesystem::exists(updater_path) && !strstr(GetCommandLineA(), "-updated")) + { + PROCESS_INFORMATION pi; + memset(&pi, 0, sizeof(pi)); + STARTUPINFO si; + memset(&si, 0, sizeof(si)); + si.cb = sizeof(STARTUPINFO); + si.dwFlags = STARTF_USESHOWWINDOW; + si.wShowWindow = SW_MINIMIZE; + CreateProcessW( + updater_path.c_str(), + NULL, + NULL, + NULL, + false, + CREATE_DEFAULT_ERROR_MODE | CREATE_NEW_PROCESS_GROUP, + NULL, + NULL, + (LPSTARTUPINFOW)&si, + &pi); + exit(0); + } +} + DWORD GetProcessByName(std::wstring processName) { HANDLE snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); @@ -74,13 +141,12 @@ FARPROC GetLauncherMain() void LibraryLoadError(DWORD dwMessageId, const wchar_t* libName, const wchar_t* location) { - char text[8192]; - std::string message = std::system_category().message(dwMessageId); + wchar_t text[8192]; + std::wstring message = ConvertUtf8ToWide(std::system_category().message(dwMessageId)); - sprintf_s( + swprintf_s( text, - "Failed to load the %ls at \"%ls\" (%lu):\n\n%hs\n\nMake sure you followed the Northstar installation instructions carefully " - "before reaching out for help.", + L"无法读取文件 %ls 于 \"%ls\" (%lu):\n\n%hs\n\n请检查安装流程是否正确,或在北极星CN Wiki中查询常见疑难问题解答", libName, location, dwMessageId, @@ -88,44 +154,46 @@ void LibraryLoadError(DWORD dwMessageId, const wchar_t* libName, const wchar_t* if (dwMessageId == 126 && std::filesystem::exists(location)) { - sprintf_s( + swprintf_s( text, - "%s\n\nThe file at the specified location DOES exist, so this error indicates that one of its *dependencies* failed to be " - "found.\n\nTry the following steps:\n1. Install Visual C++ 2022 Redistributable: " - "https://aka.ms/vs/17/release/vc_redist.x64.exe\n2. Repair game files", + L"%s\n\n该文件在文件系统中存在,但其所需的其他文件可能无法被正常读取\n\n请尝试以下可能得解决方案: \n1. 安装 Visual C++ 2022 " + L"运行库: https://aka.ms/vs/17/release/vc_redist.x64.exe \n2. 在您安装游戏的平台中校验游戏完整性", text); } else if (!fs::exists("Titanfall2.exe") && (fs::exists("..\\Titanfall2.exe") || fs::exists("..\\..\\Titanfall2.exe"))) { - auto curDir = std::filesystem::current_path().filename().string(); - auto aboveDir = std::filesystem::current_path().parent_path().filename().string(); - sprintf_s( + auto curDir = std::filesystem::current_path().filename().wstring(); + auto aboveDir = std::filesystem::current_path().parent_path().filename().wstring(); + swprintf_s( text, - "%s\n\nWe detected that in your case you have extracted the files into a *subdirectory* of your Titanfall 2 " - "installation.\nPlease move all the files and folders from current folder (\"%s\") into the Titanfall 2 installation directory " - "just above (\"%s\").\n\nPlease try out the above steps by yourself before reaching out to the community for support.", + L"%s\n\n检测到您将北极星CN的安装包解压到了 Titanfall 2 的次级文件夹中 " + L"\n请删除或撤销您在当前文件夹中的操作,并将北极星CN安装包解压到与'Titanfall2.exe'同级的文件夹中", text, curDir.c_str(), aboveDir.c_str()); } else if (!fs::exists("Titanfall2.exe")) { - sprintf_s( + swprintf_s( text, - "%s\n\nRemember: you need to unpack the contents of this archive into your Titanfall 2 game installation directory, not just " - "to any random folder.", + L"%s\n\n请注意: " + L"您需要将北极星CN安装包解压到游戏安装目录,与'Titanfall2.exe'同级的文件夹中,而不是直接在压缩包内运行或解压到任意位置!", text); } else if (fs::exists("Titanfall2.exe")) { - sprintf_s( + swprintf_s( text, - "%s\n\nTitanfall2.exe has been found in the current directory: is the game installation corrupted or did you not unpack all " - "Northstar files here?", + L"%s\n\n北极星CN安装位置正确,但游戏文件可能损坏或北极星CN文件缺失\n请尝试在游戏安装平台中校验文件完整性,或重新安装北极星CN并" + L"确保所有文件都被解压到当前目录中!", text); } - MessageBoxA(GetForegroundWindow(), text, "Northstar Launcher Error", 0); + int result = MessageBoxW(GetForegroundWindow(), text, L"启动北极星CN时出现错误", 0); + if (result == IDOK) + { + ShellExecuteW(NULL, L"open", L"https://wiki.northstar.cool/#/installing-northstar/troubleshooting", NULL, NULL, SW_SHOWNORMAL); + } } void AwaitOriginStartup() @@ -348,7 +416,7 @@ int main(int argc, char* argv[]) Sleep(100); } } - + RunUpdater(); if (!GetExePathWide(exePath, sizeof(exePath))) { MessageBoxA( @@ -466,11 +534,7 @@ int main(int argc, char* argv[]) std::cout << "[*] Launching the game..." << std::endl; auto LauncherMain = GetLauncherMain(); if (!LauncherMain) - MessageBoxA( - GetForegroundWindow(), - "Failed loading launcher.dll.\nThe game cannot continue and has to exit.", - "Northstar Launcher Error", - 0); + MessageBoxW(GetForegroundWindow(), L"无法找到 launcher.dll.\n启动游戏时失败", L"错误", 0); std::cout.flush(); return ((int(/*__fastcall*/*)(HINSTANCE, HINSTANCE, LPSTR, int))LauncherMain)( diff --git a/primedev/primelauncher/ns_icon.ico b/primedev/primelauncher/ns_icon.ico index fc9ad1661..977c5cb32 100644 Binary files a/primedev/primelauncher/ns_icon.ico and b/primedev/primelauncher/ns_icon.ico differ diff --git a/primedev/scripts/clantag.cpp b/primedev/scripts/clantag.cpp new file mode 100644 index 000000000..3dba2db7f --- /dev/null +++ b/primedev/scripts/clantag.cpp @@ -0,0 +1,54 @@ +//#include "pch.h" +#include "dedicated/dedicated.h" +#include +#include "masterserver/masterserver.h" +#include +#include "squirrel/squirrel.h" +#include "engine/r2engine.h" +#include "core/hooks.h" +#include + + +AUTOHOOK_INIT() + + +ADD_SQFUNC("bool", NSSetLocalPlayerClanTag, "string clantag", "", ScriptContext::UI) +{ + std::string clantag = g_pSquirrel->getstring(sqvm, 1); + bool result = g_pMasterServerManager->SetLocalPlayerClanTag(clantag); + g_pSquirrel->pushbool(sqvm, result); + return SQRESULT_NOTNULL; +} + + +AUTOHOOK(StryderShit,engine.dll + 0x1712F0, char*,__fastcall, + (__int64* a1,int a2, char * a3,int a4,__int64 a5,const char* a6, + char* a7,int a8,void* a9,int a10,int a11, char*Src, + int a13,__int64 a14,__int64 a15,__int64 a16,__int64 a17,int a18)) +{ + //if (Src != NULL) + //spdlog::info("||HEADER | {} ||",Src); + //if (a9 != NULL) + //spdlog::info("||BODY | {} || {} ||", a9,(char*)a9); + + //if (!strcmp(a7,"datacenters_v3_tchinese.txt")) + //{ + // std::string test = "test.northstar.cool"; + // spdlog::info("{} -----> {}", test, a7); + // return StryderShit(a1, a2, (char*)test.c_str(), a4, a5, a6, a7, a8, a9, a10, a11, Src, a13, a14, a15, a16, a17, a18); + //} + //spdlog::info("{} -----> {}", a3, a7); + return StryderShit(a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,Src,a13,a14,a15,a16,a17,a18); +} + +ON_DLL_LOAD_CLIENT("engine.dll", ClantagInitializeServer, (CModule module)) +{ + AUTOHOOK_DISPATCH(); +} +ON_DLL_LOAD_CLIENT_RELIESON("client.dll", ClantagInitializeClient, ClientSquirrel, (CModule module)) +{ + + if (IsDedicatedServer()) + return; + +} diff --git a/primedev/scripts/client/scriptoriginauth.cpp b/primedev/scripts/client/scriptoriginauth.cpp index 420c48720..af95249c7 100644 --- a/primedev/scripts/client/scriptoriginauth.cpp +++ b/primedev/scripts/client/scriptoriginauth.cpp @@ -22,13 +22,13 @@ ADD_SQFUNC("MasterServerAuthResult", NSGetMasterServerAuthResult, "", "", Script { g_pSquirrel->pushnewstructinstance(sqvm, 3); - g_pSquirrel->pushbool(sqvm, g_pMasterServerManager->m_bOriginAuthWithMasterServerSuccessful); + g_pSquirrel->pushbool(sqvm, g_pMasterServerManager->m_bOriginAuthWithMasterServerSuccess); g_pSquirrel->sealstructslot(sqvm, 0); - g_pSquirrel->pushstring(sqvm, g_pMasterServerManager->m_sOriginAuthWithMasterServerErrorCode.c_str(), -1); + g_pSquirrel->pushstring(sqvm, g_pMasterServerManager->m_sAuthFailureReason.c_str(), -1); g_pSquirrel->sealstructslot(sqvm, 1); - g_pSquirrel->pushstring(sqvm, g_pMasterServerManager->m_sOriginAuthWithMasterServerErrorMessage.c_str(), -1); + g_pSquirrel->pushstring(sqvm, g_pMasterServerManager->m_sAuthFailureMessage.c_str(), -1); g_pSquirrel->sealstructslot(sqvm, 2); return SQRESULT_NOTNULL; diff --git a/primedev/scripts/client/scriptserverbrowser.cpp b/primedev/scripts/client/scriptserverbrowser.cpp index b946f7a95..39cfdad19 100644 --- a/primedev/scripts/client/scriptserverbrowser.cpp +++ b/primedev/scripts/client/scriptserverbrowser.cpp @@ -65,8 +65,8 @@ ADD_SQFUNC("void", NSTryAuthWithServer, "int serverIndex, string password = ''", // do auth g_pMasterServerManager->AuthenticateWithServer( g_pLocalPlayerUserID, - g_pMasterServerManager->m_sOwnClientAuthToken, - g_pMasterServerManager->m_vRemoteServers[serverIndex], + g_pMasterServerManager->m_sOwnClientAuthToken.c_str(), + g_pMasterServerManager->m_vRemoteServers[serverIndex].id, (char*)password); return SQRESULT_NULL; @@ -97,7 +97,7 @@ ADD_SQFUNC("void", NSConnectToAuthedServer, "", "", ScriptContext::UI) // set auth token, then try to connect // i'm honestly not entirely sure how silentconnect works regarding ports and encryption so using connect for now - g_pCVar->FindVar("serverfilter")->SetValue(info.authToken); + g_pCVar->FindVar("serverfilter")->SetValue(info.authToken.c_str()); Cbuf_AddText( Cbuf_GetCurrentPlayer(), fmt::format( @@ -185,8 +185,8 @@ ADD_SQFUNC("array", NSGetGameServers, "", "", ScriptContext::UI) g_pSquirrel->pushbool(sqvm, remoteServer.requiresPassword); g_pSquirrel->sealstructslot(sqvm, 8); - // region - g_pSquirrel->pushstring(sqvm, remoteServer.region, -1); + // gamestate + g_pSquirrel->pushinteger(sqvm, remoteServer.gameState); g_pSquirrel->sealstructslot(sqvm, 9); // requiredMods diff --git a/primedev/scripts/scriptgamestate.cpp b/primedev/scripts/scriptgamestate.cpp new file mode 100644 index 000000000..57706b4ee --- /dev/null +++ b/primedev/scripts/scriptgamestate.cpp @@ -0,0 +1,43 @@ +#include "squirrel/squirrel.h" +#include "scriptgamestate.h" +#include "client/r2client.h" +#include "engine/r2engine.h" + +SQGameState* g_pSQGameState = new SQGameState; + + +std::string GameStateToString(int gamestate) +{ + switch (gamestate) + { + case 0: + return "WaitingForCustomStart"; + case 1: + return "WaitingForPlayers"; + case 2: + return "PickLoadout"; + case 3: + return "Prematch"; + case 4: + return "Playing"; + case 5: + return "SuddenDeath"; + case 6: + return "SwitchingSides"; + case 7: + return "WinnerDetermined"; + case 8: + return "Epilogue"; + case 9: + return "Postmatch"; + } + return "none"; +} + +ADD_SQFUNC("void", NSUpdateSQGameState, "int gamestate", "", ScriptContext::SERVER) +{ + int state = g_pSquirrel->getinteger(sqvm, 1); + g_pSQGameState->eGameState = state; + spdlog::info("GAMESTATE: {} STRING: {}", state, GameStateToString(state)); + return SQRESULT_NOTNULL; +} \ No newline at end of file diff --git a/primedev/scripts/scriptgamestate.h b/primedev/scripts/scriptgamestate.h new file mode 100644 index 000000000..c96391735 --- /dev/null +++ b/primedev/scripts/scriptgamestate.h @@ -0,0 +1,9 @@ +#pragma once +#include +class SQGameState +{ + public: + int eGameState; +}; + +extern SQGameState* g_pSQGameState; diff --git a/primedev/scripts/scriptmasterservermessages.cpp b/primedev/scripts/scriptmasterservermessages.cpp new file mode 100644 index 000000000..b9be170fb --- /dev/null +++ b/primedev/scripts/scriptmasterservermessages.cpp @@ -0,0 +1,22 @@ +#include "pch.h" +#include "squirrel/squirrel.h" +#include "masterserver/masterserver.h" +#include "server/auth/serverauthentication.h" +#include "core/hooks.h" +#include "client/r2client.h" +#include "scriptmasterservermessages.h" + +MasterserverMessenger* g_pMasterserverMessenger = new MasterserverMessenger; + +ADD_SQFUNC("string", NSGetLastMasterserverMessage, "", "", ScriptContext::SERVER) +{ + if (g_pMasterserverMessenger->m_vQueuedMasterserverMessages.empty()) + { + g_pSquirrel->pushstring(sqvm, "none", -1); + return SQRESULT_NOTNULL; + } + std::string thismessage = g_pMasterserverMessenger->m_vQueuedMasterserverMessages.front(); + g_pSquirrel->pushstring(sqvm, thismessage.c_str(), -1); // message content + g_pMasterserverMessenger->m_vQueuedMasterserverMessages.pop(); + return SQRESULT_NOTNULL; +} diff --git a/primedev/scripts/scriptmasterservermessages.h b/primedev/scripts/scriptmasterservermessages.h new file mode 100644 index 000000000..5c5eaa959 --- /dev/null +++ b/primedev/scripts/scriptmasterservermessages.h @@ -0,0 +1,10 @@ +#pragma once +#include + +class MasterserverMessenger +{ + public: + std::queue m_vQueuedMasterserverMessages; +}; + +extern MasterserverMessenger* g_pMasterserverMessenger; diff --git a/primedev/scripts/scriptmatchmakingevents.cpp b/primedev/scripts/scriptmatchmakingevents.cpp new file mode 100644 index 000000000..9683b907e --- /dev/null +++ b/primedev/scripts/scriptmatchmakingevents.cpp @@ -0,0 +1,237 @@ +#include "pch.h" +#include "squirrel/squirrel.h" +#include "masterserver/masterserver.h" +#include "server/auth/serverauthentication.h" +#include "core/hooks.h" +#include "client/r2client.h" +#include "scriptmasterservermessages.h" +#include "scriptmatchmakingevents.h" +#include "core/tier0.h" +#include "core/vanilla.h" + +#define NSCN_MATCHMAKING + +AUTOHOOK_INIT() + +MatchmakeManager* g_pMatchmakerManager = new MatchmakeManager; + +struct MatchmakeStatus_Baseline +{ + // this is used when game request for status before we had a response form masterserver. + std::string status = "#MATCH_NOTHING"; + std::string playlistName = "ps"; + std::string etaSeconds = "30"; + std::string mapIdx = "1"; + std::string modeIdx = "1"; + std::string PlaylistList = "ps,aitdm"; +}; + +std::string MatchmakeInfo::GetByParam(int idx) +{ + switch (idx) + { + case 1: + return this->playlistName; + case 2: + return this->etaSeconds; + case 3: + return this->mapIdx; + case 4: + return this->modeIdx; + case 5: + return this->playlistListstr; + } + return "none"; +} + +MatchmakeManager::MatchmakeManager() +{ + // Create the request thread for use later on + std::thread* requestthreadptr = new std::thread( + [this]() + { + while (true) + { + switch (LocalState) + { + case 0: // idle + break; + case 1: // new matchmaking event + if (!g_pMasterServerManager->StartMatchmaking(info)) + { + LocalState = 0; + break; + } + else + { + LocalState = 2; + break; + } + + case 2: // matchmaking + if (!g_pMasterServerManager->UpdateMatchmakingStatus(info)) + { + LocalState = 0; + break; + } + if (!strcmp(info->status.c_str(), "#MATCHMAKING_MATCH_CONNECTING")) + { + for (auto& pair : g_pServerAuthentication->m_PlayerAuthenticationData) + g_pServerAuthentication->WritePersistentData(pair.first); + + + // matchmake done, connect to the server + g_pMasterServerManager->AuthenticateWithServer( + g_pLocalPlayerUserID, g_pMasterServerManager->m_sOwnClientAuthToken.c_str(), info->serverId.c_str(), + ""); + while (g_pMasterServerManager->m_bScriptAuthenticatingWithGameServer) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + RemoteServerConnectionInfo& conn_info = g_pMasterServerManager->m_pendingConnectionInfo; + if (!g_pMasterServerManager->m_bHasPendingConnectionInfo) + { + spdlog::error( + "[Matchmaking] Failed while authenticating with matchmaking server: {} {}", + g_pMasterServerManager->m_sAuthFailureReason, + g_pMasterServerManager->m_sAuthFailureMessage); + LocalState = 0; + break; + // TODO: cleanup + } + std::string connection_cmd = fmt::format( + "connect {}.{}.{}.{}:{}", + conn_info.ip.S_un.S_un_b.s_b1, + conn_info.ip.S_un.S_un_b.s_b2, + conn_info.ip.S_un.S_un_b.s_b3, + conn_info.ip.S_un.S_un_b.s_b4, + conn_info.port); + spdlog::info("Connect: {}", connection_cmd); + g_pCVar->FindVar("serverfilter")->SetValue(conn_info.authToken.c_str()); + Cbuf_AddText( + Cbuf_GetCurrentPlayer(), + connection_cmd.c_str(), + cmd_source_t::kCommandSrcCode); + + g_pMasterServerManager->m_bHasPendingConnectionInfo = false; + LocalState = 0; + } + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + requestthreadptr->detach(); + this->requestthread = requestthreadptr; + + // Creating a status object + MatchmakeInfo* statusptr = new MatchmakeInfo; + statusptr->mapIdx = 1; + statusptr->modeIdx = 1; + statusptr->status = ""; + this->info = statusptr; +} +void MatchmakeManager::StartMatchmake(std::string playlistlist) +{ + if (this->LocalState == 2) + { + spdlog::warn("[Matchmaker] calling StartMatchmake while matchmaking!"); + return; + } + spdlog::info("[Matchmaker] Starting Matchmake"); + this->info->playlistListstr = playlistlist; + std::vector playlistvec; + std::stringstream ss(playlistlist); + while (ss.good()) + { + std::string substr; + std::getline(ss, substr, ','); + spdlog::info("[Matchmaker] Adding playlist: {} to playlist_list", substr); + playlistvec.push_back(substr); + } + this->info->playlistList = playlistvec; + LocalState = 1; +} +void MatchmakeManager::CancelMatchmake() +{ + if (this->LocalState != 2) + { + spdlog::warn("[Matchmaker] calling StopMatchmake while not matchmaking!"); + return; + } + spdlog::info("[Matchmaker] Cancelling Matchmake"); + LocalState = 0; + g_pMasterServerManager->CancelMatchmaking(); +} + +#ifdef NSCN_MATCHMAKING +// clang-format off +AUTOHOOK(CCLIENT__StartMatchmaking, client.dll + 0x213D00, SQRESULT, __fastcall, (HSQUIRRELVM clientsqvm)) +// clang-format on +{ + const char* str = g_pSquirrel->getstring(clientsqvm, 1); + spdlog::info("[Matchmaker] Staring Matchmaking event:{}", str); + g_pMatchmakerManager->StartMatchmake(str); + + return SQRESULT_NOTNULL; +} + +// clang-format off +AUTOHOOK(CCLIENT__CancelMatchmaking, client.dll + 0x211640, SQRESULT, __fastcall, (HSQUIRRELVM clientsqvm)) +// clang-format on +{ + spdlog::info("[Matchmaker] Cancelled Matchmaking request"); + g_pMatchmakerManager->CancelMatchmake(); + + return SQRESULT_NOTNULL; +} + +// clang-format off +AUTOHOOK(CCLIENT__AreWeMatchmaking, client.dll + 0x211970, SQRESULT, __fastcall, (HSQUIRRELVM clientsqvm)) +// clang-format on +{ + g_pSquirrel->pushbool(clientsqvm, (g_pMatchmakerManager->LocalState == 2)); + return SQRESULT_NOTNULL; +} +// clang-format off +AUTOHOOK(CCLIENT__GetMyMatchmakingStatusParam, client.dll + 0x3B3570, SQRESULT, __fastcall, (HSQUIRRELVM clientsqvm)) +// clang-format on +{ + // DummyMatchmakingParamData* dummydata = new DummyMatchmakingParamData; + int param = g_pSquirrel->getinteger(clientsqvm, 1); + switch (param) + { + case 1: + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->playlistName.c_str(), -1); + case 2: + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->etaSeconds.c_str(), -1); + case 3: + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->mapIdx.c_str(), -1); + case 4: + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->modeIdx.c_str(), -1); + case 5: + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->playlistListstr.c_str(), -1); + } + + return SQRESULT_NOTNULL; +} +// clang-format off +AUTOHOOK(CCLIENT__GetMyMatchmakingStatus, client.dll + 0x3B1B70, SQRESULT, __fastcall, (HSQUIRRELVM clientsqvm)) +// clang-format on +{ + g_pSquirrel->pushstring(clientsqvm, g_pMatchmakerManager->info->status.c_str(), -1); + return SQRESULT_NOTNULL; +} + +ON_DLL_LOAD_CLIENT_RELIESON("client.dll", ScriptMatchmakingEvents, ClientSquirrel, (CModule module)) +{ + if (g_pVanillaCompatibility->GetVanillaCompatibility()) + { + spdlog::info("[Matchmaker] NSCN Matchmaker is disabled! (found -vanilla)"); + return; + } + AUTOHOOK_DISPATCH(); +} +#endif diff --git a/primedev/scripts/scriptmatchmakingevents.h b/primedev/scripts/scriptmatchmakingevents.h new file mode 100644 index 000000000..d97a7c1ec --- /dev/null +++ b/primedev/scripts/scriptmatchmakingevents.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include + +struct MatchmakeConnectionInfo +{ + public: + char authToken[32]; + in_addr ip; + unsigned short port; +}; + +const std::string MatchMakingStatus[] = { + "#MATCH_NOTHING", + "#MATCHMAKING_SEARCHING_FOR_MATCH", + "#MATCHMAKING_QUEUED", + "#MATCHMAKING_ALLOCATING_SERVER", + "#MATCHMAKING_MATCH_CONNECTING"}; + +class MatchmakeInfo +{ + public: + std::string status; + std::string playlistName; + std::string etaSeconds; + std::string mapIdx; + std::string modeIdx; + std::string playlistListstr; + std::vector playlistList; + std::string timeout; + bool serverReady; + std::string serverId; + std::string GetByParam(int idx); +}; + +class MatchmakeManager +{ + private: + std::thread* requestthread; + + public: + /* + * state: + * 0: idle + * 1: new matchmake event + * 2: matchmaking + */ + MatchmakeInfo* info; + int LocalState = 0; + MatchmakeManager(); + void StartMatchmake(std::string playlistlist); + void CancelMatchmake(); +}; + +extern MatchmakeManager* g_pMatchmakerManager; diff --git a/primedev/scripts/scriptsvm.cpp b/primedev/scripts/scriptsvm.cpp new file mode 100644 index 000000000..10b0849a3 --- /dev/null +++ b/primedev/scripts/scriptsvm.cpp @@ -0,0 +1,131 @@ +#include "squirrel/squirrel.h" +#include "server/svm.h" + +const float COORD_MIN = -16384.0; +const float COORD_LEN = 32767.0; + +struct SvmPoint +{ + Vector3 origin; + int faction; +}; + +void releasehook(void* val, int size) +{ + svm_free_and_destroy_model((svm_model**)val); +} + +void svm_print_string(const char* str) +{ +} + +ADD_SQFUNC( + "userdata", + NSSvmTrain, + "array points", + "train svm model based on provided points", + ScriptContext::CLIENT | ScriptContext::SERVER | ScriptContext::UI) +{ + SQArray* originsArray = sqvm->_stackOfCurrentFunction[1]._VAL.asArray; + std::vector points; + + for (int vIdx = 0; vIdx < originsArray->_usedSlots; ++vIdx) + { + if (originsArray->_values[vIdx]._Type == OT_TABLE) + { + SQTable* originTable = originsArray->_values[vIdx]._VAL.asTable; + SvmPoint newpoint; + for (int idx = 0; idx < originTable->_numOfNodes; ++idx) + { + SQTable::_HashNode* node = &originTable->_nodes[idx]; + + if (node->val._Type == OT_VECTOR) + { + SQVector* v = (SQVector*)node; + newpoint.origin = Vector3(v->x,v->y, v->z); + } + if (node->val._Type == OT_INTEGER) + { + newpoint.faction = node->val._VAL.as64Integer; + } + } + points.push_back(newpoint); + } + } + + svm_parameter param; + param.svm_type = C_SVC; + param.kernel_type = RBF; + param.degree = 3; + param.gamma = 1.0/3; + param.coef0 = 0; + param.nu = 0.5; + param.cache_size = 100; + param.C = 1000; + param.eps = 1e-3; + param.p = 0.1; + param.shrinking = 1; + param.probability = 0; + param.nr_weight = 0; + param.weight_label = NULL; + param.weight = NULL; + + svm_problem prob; + prob.l = points.size(); + prob.y = new double[prob.l]; + svm_node* x_space = new svm_node[4 * prob.l]; + prob.x = new svm_node*[prob.l]; + + for (int i = 0; i < points.size(); ++i) + { + x_space[4 * i].index = 1; + x_space[4 * i].value = ((points[i].origin.x - COORD_MIN) / COORD_LEN); + x_space[4 * i + 1].index = 2; + x_space[4 * i + 1].value = ((points[i].origin.y - COORD_MIN) / COORD_LEN); + x_space[4 * i + 2].index = 3; + x_space[4 * i + 2].value = ((points[i].origin.z - COORD_MIN) / COORD_LEN); + x_space[4 * i + 3].index = -1; + prob.x[i] = &x_space[4 * i]; + prob.y[i] = points[i].faction; + } + svm_set_print_string_function(svm_print_string); + svm_model* model = svm_train(&prob, ¶m); + svm_model** userdata_model = g_pSquirrel->template createuserdata(sqvm, 8); + *userdata_model = model; + + SQUserData* userdata = (SQUserData*)(((uintptr_t)userdata_model) - 80); + userdata->releasehook = releasehook; + SQObject* object = new SQObject; + object->_Type = OT_USERDATA; + object->structNumber = 0; + object->_VAL.asUserdata = userdata; + g_pSquirrel->pushobject(sqvm, object); + return SQRESULT_NOTNULL; +} + +ADD_SQFUNC( + "int", + NSSvmPredict, + "userdata model, vector origin", + "predict things", + ScriptContext::CLIENT | ScriptContext::SERVER | ScriptContext::UI) +{ + SQObject* obj = new SQObject; + svm_model** model = nullptr; + g_pSquirrel->__sq_getobject(sqvm, 1, obj); + model = (svm_model**)obj->_VAL.asUserdata->data; + Vector3 vec =g_pSquirrel->getvector(sqvm, 2); + + svm_node x[4]; + x[0].index = 1; + x[0].value = (vec.x - COORD_MIN) / COORD_LEN; + x[1].index = 2; + x[1].value = (vec.y - COORD_MIN) / COORD_LEN; + x[2].index = 3; + x[2].value = (vec.z - COORD_MIN) / COORD_LEN; + x[3].index = -1; + // code + int result = svm_predict(*model, x); + g_pSquirrel->pushinteger(sqvm, result); + return SQRESULT_NOTNULL; +} diff --git a/primedev/server/auth/serverauthentication.cpp b/primedev/server/auth/serverauthentication.cpp index 58268bcfa..5212689cb 100644 --- a/primedev/server/auth/serverauthentication.cpp +++ b/primedev/server/auth/serverauthentication.cpp @@ -6,6 +6,7 @@ #include "server/serverpresence.h" #include "engine/hoststate.h" #include "bansystem.h" +#include "util/base64.h" #include "core/convar/concommand.h" #include "dedicated/dedicated.h" #include "config/profile.h" @@ -13,29 +14,119 @@ #include "engine/r2engine.h" #include "client/r2client.h" #include "server/r2server.h" +#include "scripts/scriptmasterservermessages.h" +#include "cpp-httplib/httplib.h" +#include "nlohmann/json.hpp" +#include "shared/playlist.h" #include #include -#include #include +const char* AUTHSERVER_VERIFY_STRING = "I am a northstar server!"; + // global vars ServerAuthenticationManager* g_pServerAuthentication; CBaseServer__RejectConnectionType CBaseServer__RejectConnection; -void ServerAuthenticationManager::AddRemotePlayer(std::string token, uint64_t uid, std::string username, std::string pdata) +void ServerAuthenticationManager::StartPlayerAuthServer() { - std::string uidS = std::to_string(uid); + if (m_bRunningPlayerAuthThread) + { + spdlog::warn("ServerAuthenticationManager::StartPlayerAuthServer was called while m_bRunningPlayerAuthThread is true"); + return; + } - RemoteAuthData newAuthData {}; - strncpy_s(newAuthData.uid, sizeof(newAuthData.uid), uidS.c_str(), uidS.length()); - strncpy_s(newAuthData.username, sizeof(newAuthData.username), username.c_str(), username.length()); - newAuthData.pdata = new char[pdata.length()]; - newAuthData.pdataSize = pdata.length(); - memcpy(newAuthData.pdata, pdata.c_str(), newAuthData.pdataSize); + g_pServerPresence->SetAuthPort(Cvar_ns_player_auth_port->GetInt()); // set auth port for presence + m_bRunningPlayerAuthThread = true; - std::lock_guard guard(m_AuthDataMutex); - m_RemoteAuthenticationData[token] = newAuthData; + // listen is a blocking call so thread this + std::thread serverThread( + [this] + { + // this is just a super basic way to verify that servers have ports open, masterserver will try to read this before ensuring + // server is legit + m_PlayerAuthServer.Post( + "/rui_message", + [this](const httplib::Request& request, httplib::Response& response) + { + if (!request.has_param("serverAuthToken") || + strcmp(g_pMasterServerManager->m_sOwnServerAuthToken, request.get_param_value("serverAuthToken").c_str())) + { + // return; + } + + g_pMasterserverMessenger->m_vQueuedMasterserverMessages.push(request.body); + + response.set_content("{\"success\":true}", "application/json"); + }); + + m_PlayerAuthServer.Get( + "/verify", + [](const httplib::Request& request, httplib::Response& response) + { response.set_content(AUTHSERVER_VERIFY_STRING, "text/plain"); }); + + m_PlayerAuthServer.Get( + "/status", + [](const httplib::Request& request, httplib::Response& response) + { + std::string result; + nlohmann::json json; + int maxplayers = 0; + auto p_maxplayers = R2::GetCurrentPlaylistVar("max_players", true); + if (p_maxplayers) + { + maxplayers = std::stoi(p_maxplayers); + } + json["maxplayers"] = maxplayers; + json["playercount"] = g_pServerAuthentication->m_PlayerAuthenticationData.size(); + result = json.dump(); + response.set_content(result, "application/json"); + }); + + m_PlayerAuthServer.Post( + "/authenticate_incoming_player", + [this](const httplib::Request& request, httplib::Response& response) + { + if (!request.has_param("id") || !request.has_param("authToken") || request.body.size() >= 128000 || + !request.has_param("serverAuthToken") || + strcmp(g_pMasterServerManager->m_sOwnServerAuthToken, request.get_param_value("serverAuthToken").c_str())) + { + response.set_content("{\"success\":false}", "application/json"); + return; + } + + RemoteAuthData newAuthData; + if (request.has_param("clantag")) + { + newAuthData.clantag = request.get_param_value("clantag"); + } + newAuthData.uid = request.get_param_value("id"); + newAuthData.username = request.get_param_value("username"); + newAuthData.pdata = base64_decode(request.body); + + std::lock_guard guard(m_AuthDataMutex); + m_RemoteAuthenticationData.insert(std::make_pair(request.get_param_value("authToken"), newAuthData)); + + response.set_content("{\"success\":true}", "application/json"); + }); + + m_PlayerAuthServer.listen("0.0.0.0", Cvar_ns_player_auth_port->GetInt()); + }); + + serverThread.detach(); +} + +void ServerAuthenticationManager::StopPlayerAuthServer() +{ + if (!m_bRunningPlayerAuthThread) + { + spdlog::warn("ServerAuthenticationManager::StopPlayerAuthServer was called while m_bRunningPlayerAuthThread is false"); + return; + } + + m_bRunningPlayerAuthThread = false; + m_PlayerAuthServer.stop(); } void ServerAuthenticationManager::AddPlayer(CBaseClient* pPlayer, const char* pToken) @@ -44,7 +135,7 @@ void ServerAuthenticationManager::AddPlayer(CBaseClient* pPlayer, const char* pT auto remoteAuthData = m_RemoteAuthenticationData.find(pToken); if (remoteAuthData != m_RemoteAuthenticationData.end()) - additionalData.pdataSize = remoteAuthData->second.pdataSize; + additionalData.pdataSize = remoteAuthData->second.pdata.size(); else additionalData.pdataSize = PERSISTENCE_MAX_SIZE; @@ -66,8 +157,8 @@ bool ServerAuthenticationManager::VerifyPlayerName(const char* pAuthToken, const // always use name from masterserver if available // use of strncpy_s here should verify that this is always nullterminated within valid buffer size auto authData = m_RemoteAuthenticationData.find(pAuthToken); - if (authData != m_RemoteAuthenticationData.end() && *authData->second.username) - strncpy_s(pOutVerifiedName, 64, authData->second.username, 63); + if (authData != m_RemoteAuthenticationData.end() && !authData->second.username.empty()) + strncpy_s(pOutVerifiedName, 64, authData->second.username.c_str(), 63); else strncpy_s(pOutVerifiedName, 64, pName, 63); @@ -118,7 +209,7 @@ bool ServerAuthenticationManager::CheckAuthentication(CBaseClient* pPlayer, uint std::lock_guard guard(m_AuthDataMutex); auto authData = m_RemoteAuthenticationData.find(pAuthToken); - if (authData != m_RemoteAuthenticationData.end() && !strcmp(sUid.c_str(), authData->second.uid)) + if (authData != m_RemoteAuthenticationData.end() && sUid == authData->second.uid) return true; return false; @@ -143,7 +234,7 @@ void ServerAuthenticationManager::AuthenticatePlayer(CBaseClient* pPlayer, uint6 if (!m_bForceResetLocalPlayerPersistence || strcmp(sUid.c_str(), g_pLocalPlayerUserID)) { // copy pdata into buffer - memcpy(pPlayer->m_PersistenceBuffer, authData->second.pdata, authData->second.pdataSize); + memcpy(pPlayer->m_PersistenceBuffer, authData->second.pdata.data(), authData->second.pdata.size()); } // set persistent data as ready @@ -170,13 +261,11 @@ bool ServerAuthenticationManager::RemovePlayerAuthData(CBaseClient* pPlayer) // we don't have our auth token at this point, so lookup authdata by uid for (auto& auth : m_RemoteAuthenticationData) { - if (!strcmp(pPlayer->m_UID, auth.second.uid)) + if (pPlayer->m_UID == auth.second.uid) { // pretty sure this is fine, since we don't iterate after the erase // i think if we iterated after it'd be undefined behaviour tho std::lock_guard guard(m_AuthDataMutex); - - delete[] auth.second.pdata; m_RemoteAuthenticationData.erase(auth.first); return true; } @@ -286,6 +375,15 @@ h_CBaseClient__Connect(CBaseClient* self, char* pName, void* pNetChannel, char b // we already know this player's authentication data is legit, actually write it to them now g_pServerAuthentication->AuthenticatePlayer(self, iNextPlayerUid, pNextPlayerToken); + if (!g_pServerAuthentication->m_RemoteAuthenticationData[pNextPlayerToken].clantag.empty()) + { + // std::string nameWithTag = "[" + g_pServerAuthentication->m_RemoteAuthenticationData[pNextPlayerToken].clantag + "]" + + // self->m_Name; strncpy_s(self->m_Name, nameWithTag.c_str(), 64); + char* clantag = ((char*)self + 0x318 + 64); + strncpy(clantag, g_pServerAuthentication->m_RemoteAuthenticationData[pNextPlayerToken].clantag.c_str(), 16); + clantag[15] = '\0'; + } + g_pServerAuthentication->AddPlayer(self, pNextPlayerToken); g_pServerLimits->AddPlayer(self); @@ -302,7 +400,7 @@ static void h_CBaseClient__ActivatePlayer(CBaseClient* self) { g_pServerAuthentication->m_bForceResetLocalPlayerPersistence = false; g_pServerAuthentication->WritePersistentData(self); - g_pServerPresence->SetPlayerCount((int)g_pServerAuthentication->m_PlayerAuthenticationData.size()); + g_pServerPresence->SetPlayerCount(g_pServerAuthentication->m_PlayerAuthenticationData.size()); } o_pCBaseClient__ActivatePlayer(self); @@ -342,7 +440,6 @@ static void h_CBaseClient__Disconnect(CBaseClient* self, uint32_t unknownButAlwa void ConCommand_ns_resetpersistence(const CCommand& args) { - NOTE_UNUSED(args); if (*g_pServerState == server_state_t::ss_active) { spdlog::error("ns_resetpersistence must be entered from the main menu"); @@ -369,6 +466,7 @@ ON_DLL_LOAD_RELIESON("engine.dll", ServerAuthentication, (ConCommand, ConVar), ( g_pServerAuthentication = new ServerAuthenticationManager; + g_pServerAuthentication->Cvar_ns_player_auth_port = new ConVar("ns_player_auth_port", "8081", FCVAR_GAMEDLL, ""); g_pServerAuthentication->Cvar_ns_erase_auth_info = new ConVar("ns_erase_auth_info", "1", FCVAR_GAMEDLL, "Whether auth info should be erased from this server on disconnect or crash"); g_pServerAuthentication->Cvar_ns_auth_allow_insecure = diff --git a/primedev/server/auth/serverauthentication.h b/primedev/server/auth/serverauthentication.h index 996d20e1c..0f1b75193 100644 --- a/primedev/server/auth/serverauthentication.h +++ b/primedev/server/auth/serverauthentication.h @@ -1,17 +1,16 @@ #pragma once #include "core/convar/convar.h" +#include "cpp-httplib/httplib.h" #include "engine/r2engine.h" #include #include struct RemoteAuthData { - char uid[33]; - char username[64]; - - // pdata - char* pdata; - size_t pdataSize; + std::string uid; + std::string username; + std::string clantag; + std::vector pdata; }; struct PlayerAuthenticationData @@ -26,7 +25,11 @@ extern CBaseServer__RejectConnectionType CBaseServer__RejectConnection; class ServerAuthenticationManager { +private: + httplib::Server m_PlayerAuthServer; + public: + ConVar* Cvar_ns_player_auth_port; ConVar* Cvar_ns_erase_auth_info; ConVar* Cvar_ns_auth_allow_insecure; ConVar* Cvar_ns_auth_allow_insecure_write; @@ -36,12 +39,14 @@ class ServerAuthenticationManager std::unordered_map m_PlayerAuthenticationData; bool m_bAllowDuplicateAccounts = false; + bool m_bRunningPlayerAuthThread = false; bool m_bNeedLocalAuthForNewgame = false; bool m_bForceResetLocalPlayerPersistence = false; bool m_bStartingLocalSPGame = false; public: - void AddRemotePlayer(std::string token, uint64_t uid, std::string username, std::string pdata); + void StartPlayerAuthServer(); + void StopPlayerAuthServer(); void AddPlayer(CBaseClient* pPlayer, const char* pAuthToken); void RemovePlayer(CBaseClient* pPlayer); diff --git a/primedev/server/servernethooks.cpp b/primedev/server/servernethooks.cpp deleted file mode 100644 index 148b735fa..000000000 --- a/primedev/server/servernethooks.cpp +++ /dev/null @@ -1,218 +0,0 @@ -#include "core/convar/convar.h" -#include "engine/r2engine.h" -#include "shared/exploit_fixes/ns_limits.h" -#include "masterserver/masterserver.h" - -#include -#include -#include - -AUTOHOOK_INIT() - -static ConVar* Cvar_net_debug_atlas_packet; -static ConVar* Cvar_net_debug_atlas_packet_insecure; - -static BCRYPT_ALG_HANDLE HMACSHA256; -constexpr size_t HMACSHA256_LEN = 256 / 8; - -static bool InitHMACSHA256() -{ - NTSTATUS status; - DWORD hashLength = 0; - ULONG hashLengthSz = 0; - - if ((status = BCryptOpenAlgorithmProvider(&HMACSHA256, BCRYPT_SHA256_ALGORITHM, NULL, BCRYPT_ALG_HANDLE_HMAC_FLAG))) - { - spdlog::error("failed to initialize HMAC-SHA256: BCryptOpenAlgorithmProvider: error 0x{:08X}", (ULONG)status); - return false; - } - - if ((status = BCryptGetProperty(HMACSHA256, BCRYPT_HASH_LENGTH, (PUCHAR)&hashLength, sizeof(hashLength), &hashLengthSz, 0))) - { - spdlog::error("failed to initialize HMAC-SHA256: BCryptGetProperty(BCRYPT_HASH_LENGTH): error 0x{:08X}", (ULONG)status); - return false; - } - - if (hashLength != HMACSHA256_LEN) - { - spdlog::error("failed to initialize HMAC-SHA256: BCryptGetProperty(BCRYPT_HASH_LENGTH): unexpected value {}", hashLength); - return false; - } - - return true; -} - -// compare the HMAC-SHA256(data, key) against sig (note: all strings are treated as raw binary data) -static bool VerifyHMACSHA256(std::string key, std::string sig, std::string data) -{ - uint8_t invalid = 1; - char hash[HMACSHA256_LEN]; - - NTSTATUS status; - BCRYPT_HASH_HANDLE h = NULL; - - if ((status = BCryptCreateHash(HMACSHA256, &h, NULL, 0, (PUCHAR)key.c_str(), (ULONG)key.length(), 0))) - { - spdlog::error("failed to verify HMAC-SHA256: BCryptCreateHash: error 0x{:08X}", (ULONG)status); - goto cleanup; - } - - if ((status = BCryptHashData(h, (PUCHAR)data.c_str(), (ULONG)data.length(), 0))) - { - spdlog::error("failed to verify HMAC-SHA256: BCryptHashData: error 0x{:08X}", (ULONG)status); - goto cleanup; - } - - if ((status = BCryptFinishHash(h, (PUCHAR)&hash, (ULONG)sizeof(hash), 0))) - { - spdlog::error("failed to verify HMAC-SHA256: BCryptFinishHash: error 0x{:08X}", (ULONG)status); - goto cleanup; - } - - // constant-time compare - if (sig.length() == sizeof(hash)) - { - invalid = 0; - for (size_t i = 0; i < sizeof(hash); i++) - invalid |= (uint8_t)(sig[i]) ^ (uint8_t)(hash[i]); - } - -cleanup: - if (h) - BCryptDestroyHash(h); - return !invalid; -} - -// v1 HMACSHA256-signed masterserver request (HMAC-SHA256(JSONData, MasterServerToken) + JSONData) -static void ProcessAtlasConnectionlessPacketSigreq1(netpacket_t* packet, bool dbg, std::string pType, std::string pData) -{ - if (pData.length() < HMACSHA256_LEN) - { - if (dbg) - spdlog::warn("ignoring Atlas connectionless packet (size={} type={}): invalid: too short for signature", packet->size, pType); - return; - } - - std::string pSig; // is binary data, not actually an ASCII string - pSig = pData.substr(0, HMACSHA256_LEN); - pData = pData.substr(HMACSHA256_LEN); - - if (!g_pMasterServerManager || !g_pMasterServerManager->m_sOwnServerAuthToken[0]) - { - if (dbg) - spdlog::warn( - "ignoring Atlas connectionless packet (size={} type={}): invalid (data={}): no masterserver token yet", - packet->size, - pType, - pData); - return; - } - - if (!VerifyHMACSHA256(std::string(g_pMasterServerManager->m_sOwnServerAuthToken), pSig, pData)) - { - if (!Cvar_net_debug_atlas_packet_insecure->GetBool()) - { - if (dbg) - spdlog::warn( - "ignoring Atlas connectionless packet (size={} type={}): invalid: invalid signature (key={})", - packet->size, - pType, - std::string(g_pMasterServerManager->m_sOwnServerAuthToken)); - return; - } - spdlog::warn( - "processing Atlas connectionless packet (size={} type={}) with invalid signature due to net_debug_atlas_packet_insecure", - packet->size, - pType); - } - - if (dbg) - spdlog::info("got Atlas connectionless packet (size={} type={} data={})", packet->size, pType, pData); - - std::thread t(&MasterServerManager::ProcessConnectionlessPacketSigreq1, g_pMasterServerManager, pData); - t.detach(); - - return; -} - -static void ProcessAtlasConnectionlessPacket(netpacket_t* packet) -{ - bool dbg = Cvar_net_debug_atlas_packet->GetBool(); - - // extract kind, null-terminated type, data - std::string pType, pData; - for (int i = 5; i < packet->size; i++) - { - if (packet->data[i] == '\x00') - { - pType.assign((char*)(&packet->data[5]), (size_t)(i - 5)); - if (i + 1 < packet->size) - pData.assign((char*)(&packet->data[i + 1]), (size_t)(packet->size - i - 1)); - break; - } - } - - // note: all Atlas connectionless packets should be idempotent so multiple attempts can be made to mitigate packet loss - // note: all long-running Atlas connectionless packet handlers should be started in a new thread (with copies of the data) to avoid - // blocking networking - - // v1 HMACSHA256-signed masterserver request - if (pType == "sigreq1") - { - ProcessAtlasConnectionlessPacketSigreq1(packet, dbg, pType, pData); - return; - } - - if (dbg) - spdlog::warn("ignoring Atlas connectionless packet (size={} type={}): unknown type", packet->size, pType); - return; -} - -AUTOHOOK(ProcessConnectionlessPacket, engine.dll + 0x117800, bool, , (void* a1, netpacket_t* packet)) -{ - // packet->data consists of 0xFFFFFFFF (int32 -1) to indicate packets aren't split, followed by a header consisting of a single - // character, which is used to uniquely identify the packet kind. Most kinds follow this with a null-terminated string payload - // then an arbitrary amoount of data. - - // T (no rate limits since we authenticate packets before doing anything expensive) - if (4 < packet->size && packet->data[4] == 'T') - { - ProcessAtlasConnectionlessPacket(packet); - return false; - } - - // check rate limits for the original unconnected packets - if (!g_pServerLimits->CheckConnectionlessPacketLimits(packet)) - return false; - - // A, H, I, N - return ProcessConnectionlessPacket(a1, packet); -} - -ON_DLL_LOAD_RELIESON("engine.dll", ServerNetHooks, ConVar, (CModule module)) -{ - AUTOHOOK_DISPATCH_MODULE(engine.dll) - - if (!InitHMACSHA256()) - throw std::runtime_error("failed to initialize bcrypt"); - - if (!VerifyHMACSHA256( - "test", - "\x88\xcd\x21\x08\xb5\x34\x7d\x97\x3c\xf3\x9c\xdf\x90\x53\xd7\xdd\x42\x70\x48\x76\xd8\xc9\xa9\xbd\x8e\x2d\x16\x82\x59\xd3\xdd" - "\xf7", - "test")) - throw std::runtime_error("bcrypt HMAC-SHA256 is broken"); - - Cvar_net_debug_atlas_packet = new ConVar( - "net_debug_atlas_packet", - "0", - FCVAR_NONE, - "Whether to log detailed debugging information for Atlas connectionless packets (warning: this allows unlimited amounts of " - "arbitrary data to be logged)"); - - Cvar_net_debug_atlas_packet_insecure = new ConVar( - "net_debug_atlas_packet_insecure", - "0", - FCVAR_NONE, - "Whether to disable signature verification for Atlas connectionless packets (DANGEROUS: this allows anyone to impersonate Atlas)"); -} diff --git a/primedev/server/serverpresence.cpp b/primedev/server/serverpresence.cpp index 099f6e648..eb21e39a3 100644 --- a/primedev/server/serverpresence.cpp +++ b/primedev/server/serverpresence.cpp @@ -178,6 +178,12 @@ void ServerPresenceManager::SetPort(const int iPort) m_ServerPresence.m_iPort = iPort; } +void ServerPresenceManager::SetAuthPort(const int iAuthPort) +{ + // update authport + m_ServerPresence.m_iAuthPort = iAuthPort; +} + void ServerPresenceManager::SetName(const std::string sServerNameUnicode) { // update name diff --git a/primedev/server/serverpresence.h b/primedev/server/serverpresence.h index 07c6fb551..c6f88a9d5 100644 --- a/primedev/server/serverpresence.h +++ b/primedev/server/serverpresence.h @@ -5,8 +5,7 @@ struct ServerPresence { public: int m_iPort; - - std::string m_sServerId; + int m_iAuthPort; std::string m_sServerName; std::string m_sServerDesc; @@ -24,8 +23,7 @@ struct ServerPresence ServerPresence(const ServerPresence* obj) { m_iPort = obj->m_iPort; - - m_sServerId = obj->m_sServerId; + m_iAuthPort = obj->m_iAuthPort; m_sServerName = obj->m_sServerName; m_sServerDesc = obj->m_sServerDesc; @@ -78,6 +76,7 @@ class ServerPresenceManager void RunFrame(double flCurrentTime); void SetPort(const int iPort); + void SetAuthPort(const int iPort); void SetName(const std::string sServerNameUnicode); void SetDescription(const std::string sServerDescUnicode); diff --git a/primedev/server/svm.cpp b/primedev/server/svm.cpp new file mode 100644 index 000000000..d4e027231 --- /dev/null +++ b/primedev/server/svm.cpp @@ -0,0 +1,3312 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "svm.h" +#ifdef _OPENMP +#include +#endif + +int libsvm_version = LIBSVM_VERSION; +typedef float Qfloat; +typedef signed char schar; +#ifndef min +template static inline T min(T x,T y) { return (x static inline T max(T x,T y) { return (x>y)?x:y; } +#endif +template static inline void swap(T& x, T& y) { T t=x; x=y; y=t; } +template static inline void clone(T*& dst, S* src, int n) +{ + dst = new T[n]; + memcpy((void *)dst,(void *)src,sizeof(T)*n); +} +static inline double powi(double base, int times) +{ + double tmp = base, ret = 1.0; + + for(int t=times; t>0; t/=2) + { + if(t%2==1) ret*=tmp; + tmp = tmp * tmp; + } + return ret; +} +#define INF HUGE_VAL +#define TAU 1e-12 +#define Malloc(type,n) (type *)malloc((n)*sizeof(type)) + +static void print_string_stdout(const char *s) +{ + fputs(s,stdout); + fflush(stdout); +} +static void (*svm_print_string) (const char *) = &print_string_stdout; +#if 1 +static void info(const char *fmt,...) +{ + char buf[BUFSIZ]; + va_list ap; + va_start(ap,fmt); + vsnprintf(buf,BUFSIZ,fmt,ap); + va_end(ap); + (*svm_print_string)(buf); +} +#else +static void info(const char *fmt,...) {} +#endif + +// +// Kernel Cache +// +// l is the number of total data items +// size is the cache size limit in bytes +// +class Cache +{ +public: + Cache(int l,long int size); + ~Cache(); + + // request data [0,len) + // return some position p where [p,len) need to be filled + // (p >= len if nothing needs to be filled) + int get_data(const int index, Qfloat **data, int len); + void swap_index(int i, int j); +private: + int l; + long int size; + struct head_t + { + head_t *prev, *next; // a circular list + Qfloat *data; + int len; // data[0,len) is cached in this entry + }; + + head_t *head; + head_t lru_head; + void lru_delete(head_t *h); + void lru_insert(head_t *h); +}; + +Cache::Cache(int l_,long int size_):l(l_),size(size_) +{ + head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0 + size /= sizeof(Qfloat); + size -= l * sizeof(head_t) / sizeof(Qfloat); + size = max(size, 2 * (long int) l); // cache must be large enough for two columns + lru_head.next = lru_head.prev = &lru_head; +} + +Cache::~Cache() +{ + for(head_t *h = lru_head.next; h != &lru_head; h=h->next) + free(h->data); + free(head); +} + +void Cache::lru_delete(head_t *h) +{ + // delete from current location + h->prev->next = h->next; + h->next->prev = h->prev; +} + +void Cache::lru_insert(head_t *h) +{ + // insert to last position + h->next = &lru_head; + h->prev = lru_head.prev; + h->prev->next = h; + h->next->prev = h; +} + +int Cache::get_data(const int index, Qfloat **data, int len) +{ + head_t *h = &head[index]; + if(h->len) lru_delete(h); + int more = len - h->len; + + if(more > 0) + { + // free old space + while(size < more) + { + head_t *old = lru_head.next; + lru_delete(old); + free(old->data); + size += old->len; + old->data = 0; + old->len = 0; + } + + // allocate new space + h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len); + size -= more; + swap(h->len,len); + } + + lru_insert(h); + *data = h->data; + return len; +} + +void Cache::swap_index(int i, int j) +{ + if(i==j) return; + + if(head[i].len) lru_delete(&head[i]); + if(head[j].len) lru_delete(&head[j]); + swap(head[i].data,head[j].data); + swap(head[i].len,head[j].len); + if(head[i].len) lru_insert(&head[i]); + if(head[j].len) lru_insert(&head[j]); + + if(i>j) swap(i,j); + for(head_t *h = lru_head.next; h!=&lru_head; h=h->next) + { + if(h->len > i) + { + if(h->len > j) + swap(h->data[i],h->data[j]); + else + { + // give up + lru_delete(h); + free(h->data); + size += h->len; + h->data = 0; + h->len = 0; + } + } + } +} + +// +// Kernel evaluation +// +// the static method k_function is for doing single kernel evaluation +// the constructor of Kernel prepares to calculate the l*l kernel matrix +// the member function get_Q is for getting one column from the Q Matrix +// +class QMatrix { +public: + virtual Qfloat *get_Q(int column, int len) const = 0; + virtual double *get_QD() const = 0; + virtual void swap_index(int i, int j) const = 0; + virtual ~QMatrix() {} +}; + +class Kernel: public QMatrix { +public: + Kernel(int l, svm_node * const * x, const svm_parameter& param); + virtual ~Kernel(); + + static double k_function(const svm_node *x, const svm_node *y, + const svm_parameter& param); + virtual Qfloat *get_Q(int column, int len) const = 0; + virtual double *get_QD() const = 0; + virtual void swap_index(int i, int j) const // no so const... + { + swap(x[i],x[j]); + if(x_square) swap(x_square[i],x_square[j]); + } +protected: + + double (Kernel::*kernel_function)(int i, int j) const; + +private: + const svm_node **x; + double *x_square; + + // svm_parameter + const int kernel_type; + const int degree; + const double gamma; + const double coef0; + + static double dot(const svm_node *px, const svm_node *py); + double kernel_linear(int i, int j) const + { + return dot(x[i],x[j]); + } + double kernel_poly(int i, int j) const + { + return powi(gamma*dot(x[i],x[j])+coef0,degree); + } + double kernel_rbf(int i, int j) const + { + return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j]))); + } + double kernel_sigmoid(int i, int j) const + { + return tanh(gamma*dot(x[i],x[j])+coef0); + } + double kernel_precomputed(int i, int j) const + { + return x[i][(int)(x[j][0].value)].value; + } +}; + +Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param) +:kernel_type(param.kernel_type), degree(param.degree), + gamma(param.gamma), coef0(param.coef0) +{ + switch(kernel_type) + { + case LINEAR: + kernel_function = &Kernel::kernel_linear; + break; + case POLY: + kernel_function = &Kernel::kernel_poly; + break; + case RBF: + kernel_function = &Kernel::kernel_rbf; + break; + case SIGMOID: + kernel_function = &Kernel::kernel_sigmoid; + break; + case PRECOMPUTED: + kernel_function = &Kernel::kernel_precomputed; + break; + } + + clone(x,x_,l); + + if(kernel_type == RBF) + { + x_square = new double[l]; + for(int i=0;iindex != -1 && py->index != -1) + { + if(px->index == py->index) + { + sum += px->value * py->value; + ++px; + ++py; + } + else + { + if(px->index > py->index) + ++py; + else + ++px; + } + } + return sum; +} + +double Kernel::k_function(const svm_node *x, const svm_node *y, + const svm_parameter& param) +{ + switch(param.kernel_type) + { + case LINEAR: + return dot(x,y); + case POLY: + return powi(param.gamma*dot(x,y)+param.coef0,param.degree); + case RBF: + { + double sum = 0; + while(x->index != -1 && y->index !=-1) + { + if(x->index == y->index) + { + double d = x->value - y->value; + sum += d*d; + ++x; + ++y; + } + else + { + if(x->index > y->index) + { + sum += y->value * y->value; + ++y; + } + else + { + sum += x->value * x->value; + ++x; + } + } + } + + while(x->index != -1) + { + sum += x->value * x->value; + ++x; + } + + while(y->index != -1) + { + sum += y->value * y->value; + ++y; + } + + return exp(-param.gamma*sum); + } + case SIGMOID: + return tanh(param.gamma*dot(x,y)+param.coef0); + case PRECOMPUTED: //x: test (validation), y: SV + return x[(int)(y->value)].value; + default: + return 0; // Unreachable + } +} + +// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 +// Solves: +// +// min 0.5(\alpha^T Q \alpha) + p^T \alpha +// +// y^T \alpha = \delta +// y_i = +1 or -1 +// 0 <= alpha_i <= Cp for y_i = 1 +// 0 <= alpha_i <= Cn for y_i = -1 +// +// Given: +// +// Q, p, y, Cp, Cn, and an initial feasible point \alpha +// l is the size of vectors and matrices +// eps is the stopping tolerance +// +// solution will be put in \alpha, objective value will be put in obj +// +class Solver { +public: + Solver() {}; + virtual ~Solver() {}; + + struct SolutionInfo { + double obj; + double rho; + double upper_bound_p; + double upper_bound_n; + double r; // for Solver_NU + }; + + void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, + double *alpha_, double Cp, double Cn, double eps, + SolutionInfo* si, int shrinking); +protected: + int active_size; + schar *y; + double *G; // gradient of objective function + enum { LOWER_BOUND, UPPER_BOUND, FREE }; + char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE + double *alpha; + const QMatrix *Q; + const double *QD; + double eps; + double Cp,Cn; + double *p; + int *active_set; + double *G_bar; // gradient, if we treat free variables as 0 + int l; + bool unshrink; // XXX + + double get_C(int i) + { + return (y[i] > 0)? Cp : Cn; + } + void update_alpha_status(int i) + { + if(alpha[i] >= get_C(i)) + alpha_status[i] = UPPER_BOUND; + else if(alpha[i] <= 0) + alpha_status[i] = LOWER_BOUND; + else alpha_status[i] = FREE; + } + bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; } + bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; } + bool is_free(int i) { return alpha_status[i] == FREE; } + void swap_index(int i, int j); + void reconstruct_gradient(); + virtual int select_working_set(int &i, int &j); + virtual double calculate_rho(); + virtual void do_shrinking(); +private: + bool be_shrunk(int i, double Gmax1, double Gmax2); +}; + +void Solver::swap_index(int i, int j) +{ + Q->swap_index(i,j); + swap(y[i],y[j]); + swap(G[i],G[j]); + swap(alpha_status[i],alpha_status[j]); + swap(alpha[i],alpha[j]); + swap(p[i],p[j]); + swap(active_set[i],active_set[j]); + swap(G_bar[i],G_bar[j]); +} + +void Solver::reconstruct_gradient() +{ + // reconstruct inactive elements of G from G_bar and free variables + + if(active_size == l) return; + + int i,j; + int nr_free = 0; + + for(j=active_size;j 2*active_size*(l-active_size)) + { + for(i=active_size;iget_Q(i,active_size); + for(j=0;jget_Q(i,l); + double alpha_i = alpha[i]; + for(j=active_size;jl = l; + this->Q = &Q; + QD=Q.get_QD(); + clone(p, p_,l); + clone(y, y_,l); + clone(alpha,alpha_,l); + this->Cp = Cp; + this->Cn = Cn; + this->eps = eps; + unshrink = false; + + // initialize alpha_status + { + alpha_status = new char[l]; + for(int i=0;iINT_MAX/100 ? INT_MAX : 100*l); + int counter = min(l,1000)+1; + + while(iter < max_iter) + { + // show progress and do shrinking + + if(--counter == 0) + { + counter = min(l,1000); + if(shrinking) do_shrinking(); + info("."); + } + + int i,j; + if(select_working_set(i,j)!=0) + { + // reconstruct the whole gradient + reconstruct_gradient(); + // reset active set size and check + active_size = l; + info("*"); + if(select_working_set(i,j)!=0) + break; + else + counter = 1; // do shrinking next iteration + } + + ++iter; + + // update alpha[i] and alpha[j], handle bounds carefully + + const Qfloat *Q_i = Q.get_Q(i,active_size); + const Qfloat *Q_j = Q.get_Q(j,active_size); + + double C_i = get_C(i); + double C_j = get_C(j); + + double old_alpha_i = alpha[i]; + double old_alpha_j = alpha[j]; + + if(y[i]!=y[j]) + { + double quad_coef = QD[i]+QD[j]+2*Q_i[j]; + if (quad_coef <= 0) + quad_coef = TAU; + double delta = (-G[i]-G[j])/quad_coef; + double diff = alpha[i] - alpha[j]; + alpha[i] += delta; + alpha[j] += delta; + + if(diff > 0) + { + if(alpha[j] < 0) + { + alpha[j] = 0; + alpha[i] = diff; + } + } + else + { + if(alpha[i] < 0) + { + alpha[i] = 0; + alpha[j] = -diff; + } + } + if(diff > C_i - C_j) + { + if(alpha[i] > C_i) + { + alpha[i] = C_i; + alpha[j] = C_i - diff; + } + } + else + { + if(alpha[j] > C_j) + { + alpha[j] = C_j; + alpha[i] = C_j + diff; + } + } + } + else + { + double quad_coef = QD[i]+QD[j]-2*Q_i[j]; + if (quad_coef <= 0) + quad_coef = TAU; + double delta = (G[i]-G[j])/quad_coef; + double sum = alpha[i] + alpha[j]; + alpha[i] -= delta; + alpha[j] += delta; + + if(sum > C_i) + { + if(alpha[i] > C_i) + { + alpha[i] = C_i; + alpha[j] = sum - C_i; + } + } + else + { + if(alpha[j] < 0) + { + alpha[j] = 0; + alpha[i] = sum; + } + } + if(sum > C_j) + { + if(alpha[j] > C_j) + { + alpha[j] = C_j; + alpha[i] = sum - C_j; + } + } + else + { + if(alpha[i] < 0) + { + alpha[i] = 0; + alpha[j] = sum; + } + } + } + + // update G + + double delta_alpha_i = alpha[i] - old_alpha_i; + double delta_alpha_j = alpha[j] - old_alpha_j; + + for(int k=0;k= max_iter) + { + if(active_size < l) + { + // reconstruct the whole gradient to calculate objective value + reconstruct_gradient(); + active_size = l; + info("*"); + } + fprintf(stderr,"\nWARNING: reaching max number of iterations\n"); + } + + // calculate rho + + si->rho = calculate_rho(); + + // calculate objective value + { + double v = 0; + int i; + for(i=0;iobj = v/2; + } + + // put back the solution + { + for(int i=0;iupper_bound_p = Cp; + si->upper_bound_n = Cn; + + info("\noptimization finished, #iter = %d\n",iter); + + delete[] p; + delete[] y; + delete[] alpha; + delete[] alpha_status; + delete[] active_set; + delete[] G; + delete[] G_bar; +} + +// return 1 if already optimal, return 0 otherwise +int Solver::select_working_set(int &out_i, int &out_j) +{ + // return i,j such that + // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) + // j: minimizes the decrease of obj value + // (if quadratic coefficeint <= 0, replace it with tau) + // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) + + double Gmax = -INF; + double Gmax2 = -INF; + int Gmax_idx = -1; + int Gmin_idx = -1; + double obj_diff_min = INF; + + for(int t=0;t= Gmax) + { + Gmax = -G[t]; + Gmax_idx = t; + } + } + else + { + if(!is_lower_bound(t)) + if(G[t] >= Gmax) + { + Gmax = G[t]; + Gmax_idx = t; + } + } + + int i = Gmax_idx; + const Qfloat *Q_i = NULL; + if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1 + Q_i = Q->get_Q(i,active_size); + + for(int j=0;j= Gmax2) + Gmax2 = G[j]; + if (grad_diff > 0) + { + double obj_diff; + double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff*grad_diff)/quad_coef; + else + obj_diff = -(grad_diff*grad_diff)/TAU; + + if (obj_diff <= obj_diff_min) + { + Gmin_idx=j; + obj_diff_min = obj_diff; + } + } + } + } + else + { + if (!is_upper_bound(j)) + { + double grad_diff= Gmax-G[j]; + if (-G[j] >= Gmax2) + Gmax2 = -G[j]; + if (grad_diff > 0) + { + double obj_diff; + double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff*grad_diff)/quad_coef; + else + obj_diff = -(grad_diff*grad_diff)/TAU; + + if (obj_diff <= obj_diff_min) + { + Gmin_idx=j; + obj_diff_min = obj_diff; + } + } + } + } + } + + if(Gmax+Gmax2 < eps || Gmin_idx == -1) + return 1; + + out_i = Gmax_idx; + out_j = Gmin_idx; + return 0; +} + +bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) +{ + if(is_upper_bound(i)) + { + if(y[i]==+1) + return(-G[i] > Gmax1); + else + return(-G[i] > Gmax2); + } + else if(is_lower_bound(i)) + { + if(y[i]==+1) + return(G[i] > Gmax2); + else + return(G[i] > Gmax1); + } + else + return(false); +} + +void Solver::do_shrinking() +{ + int i; + double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } + double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } + + // find maximal violating pair first + for(i=0;i= Gmax1) + Gmax1 = -G[i]; + } + if(!is_lower_bound(i)) + { + if(G[i] >= Gmax2) + Gmax2 = G[i]; + } + } + else + { + if(!is_upper_bound(i)) + { + if(-G[i] >= Gmax2) + Gmax2 = -G[i]; + } + if(!is_lower_bound(i)) + { + if(G[i] >= Gmax1) + Gmax1 = G[i]; + } + } + } + + if(unshrink == false && Gmax1 + Gmax2 <= eps*10) + { + unshrink = true; + reconstruct_gradient(); + active_size = l; + info("*"); + } + + for(i=0;i i) + { + if (!be_shrunk(active_size, Gmax1, Gmax2)) + { + swap_index(i,active_size); + break; + } + active_size--; + } + } +} + +double Solver::calculate_rho() +{ + double r; + int nr_free = 0; + double ub = INF, lb = -INF, sum_free = 0; + for(int i=0;i0) + r = sum_free/nr_free; + else + r = (ub+lb)/2; + + return r; +} + +// +// Solver for nu-svm classification and regression +// +// additional constraint: e^T \alpha = constant +// +class Solver_NU: public Solver +{ +public: + Solver_NU() {} + void Solve(int l, const QMatrix& Q, const double *p, const schar *y, + double *alpha, double Cp, double Cn, double eps, + SolutionInfo* si, int shrinking) + { + this->si = si; + Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking); + } +private: + SolutionInfo *si; + int select_working_set(int &i, int &j); + double calculate_rho(); + bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4); + void do_shrinking(); +}; + +// return 1 if already optimal, return 0 otherwise +int Solver_NU::select_working_set(int &out_i, int &out_j) +{ + // return i,j such that y_i = y_j and + // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) + // j: minimizes the decrease of obj value + // (if quadratic coefficeint <= 0, replace it with tau) + // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) + + double Gmaxp = -INF; + double Gmaxp2 = -INF; + int Gmaxp_idx = -1; + + double Gmaxn = -INF; + double Gmaxn2 = -INF; + int Gmaxn_idx = -1; + + int Gmin_idx = -1; + double obj_diff_min = INF; + + for(int t=0;t= Gmaxp) + { + Gmaxp = -G[t]; + Gmaxp_idx = t; + } + } + else + { + if(!is_lower_bound(t)) + if(G[t] >= Gmaxn) + { + Gmaxn = G[t]; + Gmaxn_idx = t; + } + } + + int ip = Gmaxp_idx; + int in = Gmaxn_idx; + const Qfloat *Q_ip = NULL; + const Qfloat *Q_in = NULL; + if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1 + Q_ip = Q->get_Q(ip,active_size); + if(in != -1) + Q_in = Q->get_Q(in,active_size); + + for(int j=0;j= Gmaxp2) + Gmaxp2 = G[j]; + if (grad_diff > 0) + { + double obj_diff; + double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff*grad_diff)/quad_coef; + else + obj_diff = -(grad_diff*grad_diff)/TAU; + + if (obj_diff <= obj_diff_min) + { + Gmin_idx=j; + obj_diff_min = obj_diff; + } + } + } + } + else + { + if (!is_upper_bound(j)) + { + double grad_diff=Gmaxn-G[j]; + if (-G[j] >= Gmaxn2) + Gmaxn2 = -G[j]; + if (grad_diff > 0) + { + double obj_diff; + double quad_coef = QD[in]+QD[j]-2*Q_in[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff*grad_diff)/quad_coef; + else + obj_diff = -(grad_diff*grad_diff)/TAU; + + if (obj_diff <= obj_diff_min) + { + Gmin_idx=j; + obj_diff_min = obj_diff; + } + } + } + } + } + + if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps || Gmin_idx == -1) + return 1; + + if (y[Gmin_idx] == +1) + out_i = Gmaxp_idx; + else + out_i = Gmaxn_idx; + out_j = Gmin_idx; + + return 0; +} + +bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4) +{ + if(is_upper_bound(i)) + { + if(y[i]==+1) + return(-G[i] > Gmax1); + else + return(-G[i] > Gmax4); + } + else if(is_lower_bound(i)) + { + if(y[i]==+1) + return(G[i] > Gmax2); + else + return(G[i] > Gmax3); + } + else + return(false); +} + +void Solver_NU::do_shrinking() +{ + double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } + double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } + double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } + double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } + + // find maximal violating pair first + int i; + for(i=0;i Gmax1) Gmax1 = -G[i]; + } + else if(-G[i] > Gmax4) Gmax4 = -G[i]; + } + if(!is_lower_bound(i)) + { + if(y[i]==+1) + { + if(G[i] > Gmax2) Gmax2 = G[i]; + } + else if(G[i] > Gmax3) Gmax3 = G[i]; + } + } + + if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) + { + unshrink = true; + reconstruct_gradient(); + active_size = l; + } + + for(i=0;i i) + { + if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) + { + swap_index(i,active_size); + break; + } + active_size--; + } + } +} + +double Solver_NU::calculate_rho() +{ + int nr_free1 = 0,nr_free2 = 0; + double ub1 = INF, ub2 = INF; + double lb1 = -INF, lb2 = -INF; + double sum_free1 = 0, sum_free2 = 0; + + for(int i=0;i 0) + r1 = sum_free1/nr_free1; + else + r1 = (ub1+lb1)/2; + + if(nr_free2 > 0) + r2 = sum_free2/nr_free2; + else + r2 = (ub2+lb2)/2; + + si->r = (r1+r2)/2; + return (r1-r2)/2; +} + +// +// Q matrices for various formulations +// +class SVC_Q: public Kernel +{ +public: + SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) + :Kernel(prob.l, prob.x, param) + { + clone(y,y_,prob.l); + cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); + QD = new double[prob.l]; + for(int i=0;i*kernel_function)(i,i); + } + + Qfloat *get_Q(int i, int len) const + { + Qfloat *data; + int start, j; + if((start = cache->get_data(i,&data,len)) < len) + { +#ifdef _OPENMP +#pragma omp parallel for private(j) schedule(guided) +#endif + for(j=start;j*kernel_function)(i,j)); + } + return data; + } + + double *get_QD() const + { + return QD; + } + + void swap_index(int i, int j) const + { + cache->swap_index(i,j); + Kernel::swap_index(i,j); + swap(y[i],y[j]); + swap(QD[i],QD[j]); + } + + ~SVC_Q() + { + delete[] y; + delete cache; + delete[] QD; + } +private: + schar *y; + Cache *cache; + double *QD; +}; + +class ONE_CLASS_Q: public Kernel +{ +public: + ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) + :Kernel(prob.l, prob.x, param) + { + cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); + QD = new double[prob.l]; + for(int i=0;i*kernel_function)(i,i); + } + + Qfloat *get_Q(int i, int len) const + { + Qfloat *data; + int start, j; + if((start = cache->get_data(i,&data,len)) < len) + { + for(j=start;j*kernel_function)(i,j); + } + return data; + } + + double *get_QD() const + { + return QD; + } + + void swap_index(int i, int j) const + { + cache->swap_index(i,j); + Kernel::swap_index(i,j); + swap(QD[i],QD[j]); + } + + ~ONE_CLASS_Q() + { + delete cache; + delete[] QD; + } +private: + Cache *cache; + double *QD; +}; + +class SVR_Q: public Kernel +{ +public: + SVR_Q(const svm_problem& prob, const svm_parameter& param) + :Kernel(prob.l, prob.x, param) + { + l = prob.l; + cache = new Cache(l,(long int)(param.cache_size*(1<<20))); + QD = new double[2*l]; + sign = new schar[2*l]; + index = new int[2*l]; + for(int k=0;k*kernel_function)(k,k); + QD[k+l] = QD[k]; + } + buffer[0] = new Qfloat[2*l]; + buffer[1] = new Qfloat[2*l]; + next_buffer = 0; + } + + void swap_index(int i, int j) const + { + swap(sign[i],sign[j]); + swap(index[i],index[j]); + swap(QD[i],QD[j]); + } + + Qfloat *get_Q(int i, int len) const + { + Qfloat *data; + int j, real_i = index[i]; + if(cache->get_data(real_i,&data,l) < l) + { +#ifdef _OPENMP +#pragma omp parallel for private(j) schedule(guided) +#endif + for(j=0;j*kernel_function)(real_i,j); + } + + // reorder and copy + Qfloat *buf = buffer[next_buffer]; + next_buffer = 1 - next_buffer; + schar si = sign[i]; + for(j=0;jl; + double *minus_ones = new double[l]; + schar *y = new schar[l]; + + int i; + + for(i=0;iy[i] > 0) y[i] = +1; else y[i] = -1; + } + + Solver s; + s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, + alpha, Cp, Cn, param->eps, si, param->shrinking); + + double sum_alpha=0; + for(i=0;il)); + + for(i=0;il; + double nu = param->nu; + + schar *y = new schar[l]; + + for(i=0;iy[i]>0) + y[i] = +1; + else + y[i] = -1; + + double sum_pos = nu*l/2; + double sum_neg = nu*l/2; + + for(i=0;ieps, si, param->shrinking); + double r = si->r; + + info("C = %f\n",1/r); + + for(i=0;irho /= r; + si->obj /= (r*r); + si->upper_bound_p = 1/r; + si->upper_bound_n = 1/r; + + delete[] y; + delete[] zeros; +} + +static void solve_one_class( + const svm_problem *prob, const svm_parameter *param, + double *alpha, Solver::SolutionInfo* si) +{ + int l = prob->l; + double *zeros = new double[l]; + schar *ones = new schar[l]; + int i; + + int n = (int)(param->nu*prob->l); // # of alpha's at upper bound + + for(i=0;il) + alpha[n] = param->nu * prob->l - n; + for(i=n+1;ieps, si, param->shrinking); + + delete[] zeros; + delete[] ones; +} + +static void solve_epsilon_svr( + const svm_problem *prob, const svm_parameter *param, + double *alpha, Solver::SolutionInfo* si) +{ + int l = prob->l; + double *alpha2 = new double[2*l]; + double *linear_term = new double[2*l]; + schar *y = new schar[2*l]; + int i; + + for(i=0;ip - prob->y[i]; + y[i] = 1; + + alpha2[i+l] = 0; + linear_term[i+l] = param->p + prob->y[i]; + y[i+l] = -1; + } + + Solver s; + s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, + alpha2, param->C, param->C, param->eps, si, param->shrinking); + + double sum_alpha = 0; + for(i=0;iC*l)); + + delete[] alpha2; + delete[] linear_term; + delete[] y; +} + +static void solve_nu_svr( + const svm_problem *prob, const svm_parameter *param, + double *alpha, Solver::SolutionInfo* si) +{ + int l = prob->l; + double C = param->C; + double *alpha2 = new double[2*l]; + double *linear_term = new double[2*l]; + schar *y = new schar[2*l]; + int i; + + double sum = C * param->nu * l / 2; + for(i=0;iy[i]; + y[i] = 1; + + linear_term[i+l] = prob->y[i]; + y[i+l] = -1; + } + + Solver_NU s; + s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, + alpha2, C, C, param->eps, si, param->shrinking); + + info("epsilon = %f\n",-si->r); + + for(i=0;il); + Solver::SolutionInfo si; + switch(param->svm_type) + { + case C_SVC: + solve_c_svc(prob,param,alpha,&si,Cp,Cn); + break; + case NU_SVC: + solve_nu_svc(prob,param,alpha,&si); + break; + case ONE_CLASS: + solve_one_class(prob,param,alpha,&si); + break; + case EPSILON_SVR: + solve_epsilon_svr(prob,param,alpha,&si); + break; + case NU_SVR: + solve_nu_svr(prob,param,alpha,&si); + break; + } + + info("obj = %f, rho = %f\n",si.obj,si.rho); + + // output SVs + + int nSV = 0; + int nBSV = 0; + for(int i=0;il;i++) + { + if(fabs(alpha[i]) > 0) + { + ++nSV; + if(prob->y[i] > 0) + { + if(fabs(alpha[i]) >= si.upper_bound_p) + ++nBSV; + } + else + { + if(fabs(alpha[i]) >= si.upper_bound_n) + ++nBSV; + } + } + } + + info("nSV = %d, nBSV = %d\n",nSV,nBSV); + + decision_function f; + f.alpha = alpha; + f.rho = si.rho; + return f; +} + +// Platt's binary SVM Probablistic Output: an improvement from Lin et al. +static void sigmoid_train( + int l, const double *dec_values, const double *labels, + double& A, double& B) +{ + double prior1=0, prior0 = 0; + int i; + + for (i=0;i 0) prior1+=1; + else prior0+=1; + + int max_iter=100; // Maximal number of iterations + double min_step=1e-10; // Minimal step taken in line search + double sigma=1e-12; // For numerically strict PD of Hessian + double eps=1e-5; + double hiTarget=(prior1+1.0)/(prior1+2.0); + double loTarget=1/(prior0+2.0); + double *t=Malloc(double,l); + double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; + double newA,newB,newf,d1,d2; + int iter; + + // Initial Point and Initial Fun Value + A=0.0; B=log((prior0+1.0)/(prior1+1.0)); + double fval = 0.0; + + for (i=0;i0) t[i]=hiTarget; + else t[i]=loTarget; + fApB = dec_values[i]*A+B; + if (fApB>=0) + fval += t[i]*fApB + log(1+exp(-fApB)); + else + fval += (t[i] - 1)*fApB +log(1+exp(fApB)); + } + for (iter=0;iter= 0) + { + p=exp(-fApB)/(1.0+exp(-fApB)); + q=1.0/(1.0+exp(-fApB)); + } + else + { + p=1.0/(1.0+exp(fApB)); + q=exp(fApB)/(1.0+exp(fApB)); + } + d2=p*q; + h11+=dec_values[i]*dec_values[i]*d2; + h22+=d2; + h21+=dec_values[i]*d2; + d1=t[i]-p; + g1+=dec_values[i]*d1; + g2+=d1; + } + + // Stopping Criteria + if (fabs(g1)= min_step) + { + newA = A + stepsize * dA; + newB = B + stepsize * dB; + + // New function value + newf = 0.0; + for (i=0;i= 0) + newf += t[i]*fApB + log(1+exp(-fApB)); + else + newf += (t[i] - 1)*fApB +log(1+exp(fApB)); + } + // Check sufficient decrease + if (newf=max_iter) + info("Reaching maximal iterations in two-class probability estimates\n"); + free(t); +} + +static double sigmoid_predict(double decision_value, double A, double B) +{ + double fApB = decision_value*A+B; + // 1-p used later; avoid catastrophic cancellation + if (fApB >= 0) + return exp(-fApB)/(1.0+exp(-fApB)); + else + return 1.0/(1+exp(fApB)) ; +} + +// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng to predict probabilities +static void multiclass_probability(int k, double **r, double *p) +{ + int t,j; + int iter = 0, max_iter=max(100,k); + double **Q=Malloc(double *,k); + double *Qp=Malloc(double,k); + double pQp, eps=0.005/k; + + for (t=0;tmax_error) + max_error=error; + } + if (max_error=max_iter) + info("Exceeds max_iter in multiclass_prob\n"); + for(t=0;tl); + double *dec_values = Malloc(double,prob->l); + + // random shuffle + for(i=0;il;i++) perm[i]=i; + for(i=0;il;i++) + { + int j = i+rand()%(prob->l-i); + swap(perm[i],perm[j]); + } + for(i=0;il/nr_fold; + int end = (i+1)*prob->l/nr_fold; + int j,k; + struct svm_problem subprob; + + subprob.l = prob->l-(end-begin); + subprob.x = Malloc(struct svm_node*,subprob.l); + subprob.y = Malloc(double,subprob.l); + + k=0; + for(j=0;jx[perm[j]]; + subprob.y[k] = prob->y[perm[j]]; + ++k; + } + for(j=end;jl;j++) + { + subprob.x[k] = prob->x[perm[j]]; + subprob.y[k] = prob->y[perm[j]]; + ++k; + } + int p_count=0,n_count=0; + for(j=0;j0) + p_count++; + else + n_count++; + + if(p_count==0 && n_count==0) + for(j=begin;j 0 && n_count == 0) + for(j=begin;j 0) + for(j=begin;jx[perm[j]],&(dec_values[perm[j]])); + // ensure +1 -1 order; reason not using CV subroutine + dec_values[perm[j]] *= submodel->label[0]; + } + svm_free_and_destroy_model(&submodel); + svm_destroy_param(&subparam); + } + free(subprob.x); + free(subprob.y); + } + sigmoid_train(prob->l,dec_values,prob->y,probA,probB); + free(dec_values); + free(perm); +} + +// Binning method from the oneclass_prob paper by Que and Lin to predict the probability as a normal instance (i.e., not an outlier) +static double predict_one_class_probability(const svm_model *model, double dec_value) +{ + double prob_estimate = 0.0; + int nr_marks = 10; + + if(dec_value < model->prob_density_marks[0]) + prob_estimate = 0.001; + else if(dec_value > model->prob_density_marks[nr_marks-1]) + prob_estimate = 0.999; + else + { + for(int i=1;iprob_density_marks[i]) + { + prob_estimate = (double)i/nr_marks; + break; + } + } + return prob_estimate; +} + +static int compare_double(const void *a, const void *b) +{ + if(*(double *)a > *(double *)b) + return 1; + else if(*(double *)a < *(double *)b) + return -1; + return 0; +} + +// Get parameters for one-class SVM probability estimates +static int svm_one_class_probability(const svm_problem *prob, const svm_model *model, double *prob_density_marks) +{ + double *dec_values = Malloc(double,prob->l); + double *pred_results = Malloc(double,prob->l); + int ret = 0; + int nr_marks = 10; + + for(int i=0;il;i++) + pred_results[i] = svm_predict_values(model,prob->x[i],&dec_values[i]); + qsort(dec_values,prob->l,sizeof(double),compare_double); + + int neg_counter=0; + for(int i=0;il;i++) + if(dec_values[i]>=0) + { + neg_counter = i; + break; + } + + int pos_counter = prob->l-neg_counter; + if(neg_counterl); + double mae = 0; + + svm_parameter newparam = *param; + newparam.probability = 0; + svm_cross_validation(prob,&newparam,nr_fold,ymv); + for(i=0;il;i++) + { + ymv[i]=prob->y[i]-ymv[i]; + mae += fabs(ymv[i]); + } + mae /= prob->l; + double std=sqrt(2*mae*mae); + int count=0; + mae=0; + for(i=0;il;i++) + if (fabs(ymv[i]) > 5*std) + count=count+1; + else + mae+=fabs(ymv[i]); + mae /= (prob->l-count); + info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); + free(ymv); + return mae; +} + + +// label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data +// perm, length l, must be allocated before calling this subroutine +static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm) +{ + int l = prob->l; + int max_nr_class = 16; + int nr_class = 0; + int *label = Malloc(int,max_nr_class); + int *count = Malloc(int,max_nr_class); + int *data_label = Malloc(int,l); + int i; + + for(i=0;iy[i]; + int j; + for(j=0;jparam = *param; + model->free_sv = 0; // XXX + + if(param->svm_type == ONE_CLASS || + param->svm_type == EPSILON_SVR || + param->svm_type == NU_SVR) + { + // regression or one-class-svm + model->nr_class = 2; + model->label = NULL; + model->nSV = NULL; + model->probA = NULL; model->probB = NULL; + model->prob_density_marks = NULL; + model->sv_coef = Malloc(double *,1); + + decision_function f = svm_train_one(prob,param,0,0); + model->rho = Malloc(double,1); + model->rho[0] = f.rho; + + int nSV = 0; + int i; + for(i=0;il;i++) + if(fabs(f.alpha[i]) > 0) ++nSV; + model->l = nSV; + model->SV = Malloc(svm_node *,nSV); + model->sv_coef[0] = Malloc(double,nSV); + model->sv_indices = Malloc(int,nSV); + int j = 0; + for(i=0;il;i++) + if(fabs(f.alpha[i]) > 0) + { + model->SV[j] = prob->x[i]; + model->sv_coef[0][j] = f.alpha[i]; + model->sv_indices[j] = i+1; + ++j; + } + + if(param->probability && + (param->svm_type == EPSILON_SVR || + param->svm_type == NU_SVR)) + { + model->probA = Malloc(double,1); + model->probA[0] = svm_svr_probability(prob,param); + } + else if(param->probability && param->svm_type == ONE_CLASS) + { + int nr_marks = 10; + double *prob_density_marks = Malloc(double,nr_marks); + + if(svm_one_class_probability(prob,model,prob_density_marks) == 0) + model->prob_density_marks = prob_density_marks; + else + free(prob_density_marks); + } + + free(f.alpha); + } + else + { + // classification + int l = prob->l; + int nr_class; + int *label = NULL; + int *start = NULL; + int *count = NULL; + int *perm = Malloc(int,l); + + // group training data of the same class + svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + if(nr_class == 1) + info("WARNING: training data in only one class. See README for details.\n"); + + svm_node **x = Malloc(svm_node *,l); + int i; + for(i=0;ix[perm[i]]; + + // calculate weighted C + + double *weighted_C = Malloc(double, nr_class); + for(i=0;iC; + for(i=0;inr_weight;i++) + { + int j; + for(j=0;jweight_label[i] == label[j]) + break; + if(j == nr_class) + fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]); + else + weighted_C[j] *= param->weight[i]; + } + + // train k*(k-1)/2 models + + bool *nonzero = Malloc(bool,l); + for(i=0;iprobability) + { + probA=Malloc(double,nr_class*(nr_class-1)/2); + probB=Malloc(double,nr_class*(nr_class-1)/2); + } + + int p = 0; + for(i=0;iprobability) + svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]); + + f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]); + for(k=0;k 0) + nonzero[si+k] = true; + for(k=0;k 0) + nonzero[sj+k] = true; + free(sub_prob.x); + free(sub_prob.y); + ++p; + } + + // build output + + model->nr_class = nr_class; + + model->label = Malloc(int,nr_class); + for(i=0;ilabel[i] = label[i]; + + model->rho = Malloc(double,nr_class*(nr_class-1)/2); + for(i=0;irho[i] = f[i].rho; + + if(param->probability) + { + model->probA = Malloc(double,nr_class*(nr_class-1)/2); + model->probB = Malloc(double,nr_class*(nr_class-1)/2); + for(i=0;iprobA[i] = probA[i]; + model->probB[i] = probB[i]; + } + } + else + { + model->probA=NULL; + model->probB=NULL; + } + model->prob_density_marks=NULL; // for one-class SVM probabilistic outputs only + + int total_sv = 0; + int *nz_count = Malloc(int,nr_class); + model->nSV = Malloc(int,nr_class); + for(i=0;inSV[i] = nSV; + nz_count[i] = nSV; + } + + info("Total nSV = %d\n",total_sv); + + model->l = total_sv; + model->SV = Malloc(svm_node *,total_sv); + model->sv_indices = Malloc(int,total_sv); + p = 0; + for(i=0;iSV[p] = x[i]; + model->sv_indices[p++] = perm[i] + 1; + } + + int *nz_start = Malloc(int,nr_class); + nz_start[0] = 0; + for(i=1;isv_coef = Malloc(double *,nr_class-1); + for(i=0;isv_coef[i] = Malloc(double,total_sv); + + p = 0; + for(i=0;isv_coef[j-1][q++] = f[p].alpha[k]; + q = nz_start[j]; + for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; + ++p; + } + + free(label); + free(probA); + free(probB); + free(count); + free(perm); + free(start); + free(x); + free(weighted_C); + free(nonzero); + for(i=0;il; + int *perm = Malloc(int,l); + int nr_class; + if (nr_fold > l) + { + fprintf(stderr,"WARNING: # folds (%d) > # data (%d). Will use # folds = # data instead (i.e., leave-one-out cross validation)\n", nr_fold, l); + nr_fold = l; + } + fold_start = Malloc(int,nr_fold+1); + // stratified cv may not give leave-one-out rate + // Each class to l folds -> some folds may have zero elements + if((param->svm_type == C_SVC || + param->svm_type == NU_SVC) && nr_fold < l) + { + int *start = NULL; + int *label = NULL; + int *count = NULL; + svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + + // random shuffle and then data grouped by fold using the array perm + int *fold_count = Malloc(int,nr_fold); + int c; + int *index = Malloc(int,l); + for(i=0;ix[perm[j]]; + subprob.y[k] = prob->y[perm[j]]; + ++k; + } + for(j=end;jx[perm[j]]; + subprob.y[k] = prob->y[perm[j]]; + ++k; + } + struct svm_model *submodel = svm_train(&subprob,param); + if(param->probability && + (param->svm_type == C_SVC || param->svm_type == NU_SVC)) + { + double *prob_estimates=Malloc(double,svm_get_nr_class(submodel)); + for(j=begin;jx[perm[j]],prob_estimates); + free(prob_estimates); + } + else + for(j=begin;jx[perm[j]]); + svm_free_and_destroy_model(&submodel); + free(subprob.x); + free(subprob.y); + } + free(fold_start); + free(perm); +} + + +int svm_get_svm_type(const svm_model *model) +{ + return model->param.svm_type; +} + +int svm_get_nr_class(const svm_model *model) +{ + return model->nr_class; +} + +void svm_get_labels(const svm_model *model, int* label) +{ + if (model->label != NULL) + for(int i=0;inr_class;i++) + label[i] = model->label[i]; +} + +void svm_get_sv_indices(const svm_model *model, int* indices) +{ + if (model->sv_indices != NULL) + for(int i=0;il;i++) + indices[i] = model->sv_indices[i]; +} + +int svm_get_nr_sv(const svm_model *model) +{ + return model->l; +} + +double svm_get_svr_probability(const svm_model *model) +{ + if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && + model->probA!=NULL) + return model->probA[0]; + else + { + fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); + return 0; + } +} + +double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) +{ + int i; + if(model->param.svm_type == ONE_CLASS || + model->param.svm_type == EPSILON_SVR || + model->param.svm_type == NU_SVR) + { + double *sv_coef = model->sv_coef[0]; + double sum = 0; +#ifdef _OPENMP +#pragma omp parallel for private(i) reduction(+:sum) schedule(guided) +#endif + for(i=0;il;i++) + sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); + sum -= model->rho[0]; + *dec_values = sum; + + if(model->param.svm_type == ONE_CLASS) + return (sum>0)?1:-1; + else + return sum; + } + else + { + int nr_class = model->nr_class; + int l = model->l; + + double *kvalue = Malloc(double,l); +#ifdef _OPENMP +#pragma omp parallel for private(i) schedule(guided) +#endif + for(i=0;iSV[i],model->param); + + int *start = Malloc(int,nr_class); + start[0] = 0; + for(i=1;inSV[i-1]; + + int *vote = Malloc(int,nr_class); + for(i=0;inSV[i]; + int cj = model->nSV[j]; + + int k; + double *coef1 = model->sv_coef[j-1]; + double *coef2 = model->sv_coef[i]; + for(k=0;krho[p]; + dec_values[p] = sum; + + if(dec_values[p] > 0) + ++vote[i]; + else + ++vote[j]; + p++; + } + + int vote_max_idx = 0; + for(i=1;i vote[vote_max_idx]) + vote_max_idx = i; + + free(kvalue); + free(start); + free(vote); + return model->label[vote_max_idx]; + } +} + +double svm_predict(const svm_model *model, const svm_node *x) +{ + int nr_class = model->nr_class; + double *dec_values; + if(model->param.svm_type == ONE_CLASS || + model->param.svm_type == EPSILON_SVR || + model->param.svm_type == NU_SVR) + dec_values = Malloc(double, 1); + else + dec_values = Malloc(double, nr_class*(nr_class-1)/2); + double pred_result = svm_predict_values(model, x, dec_values); + free(dec_values); + return pred_result; +} + +double svm_predict_probability( + const svm_model *model, const svm_node *x, double *prob_estimates) +{ + if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && + model->probA!=NULL && model->probB!=NULL) + { + int i; + int nr_class = model->nr_class; + double *dec_values = Malloc(double, nr_class*(nr_class-1)/2); + svm_predict_values(model, x, dec_values); + + double min_prob=1e-7; + double **pairwise_prob=Malloc(double *,nr_class); + for(i=0;iprobA[k],model->probB[k]),min_prob),1-min_prob); + pairwise_prob[j][i]=1-pairwise_prob[i][j]; + k++; + } + if (nr_class == 2) + { + prob_estimates[0] = pairwise_prob[0][1]; + prob_estimates[1] = pairwise_prob[1][0]; + } + else + multiclass_probability(nr_class,pairwise_prob,prob_estimates); + + int prob_max_idx = 0; + for(i=1;i prob_estimates[prob_max_idx]) + prob_max_idx = i; + for(i=0;ilabel[prob_max_idx]; + } + else if(model->param.svm_type == ONE_CLASS && model->prob_density_marks!=NULL) + { + double dec_value; + double pred_result = svm_predict_values(model,x,&dec_value); + prob_estimates[0] = predict_one_class_probability(model,dec_value); + prob_estimates[1] = 1-prob_estimates[0]; + return pred_result; + } + else + return svm_predict(model, x); +} + +static const char *svm_type_table[] = +{ + "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL +}; + +static const char *kernel_type_table[]= +{ + "linear","polynomial","rbf","sigmoid","precomputed",NULL +}; + +int svm_save_model(const char *model_file_name, const svm_model *model) +{ + FILE *fp = fopen(model_file_name,"w"); + if(fp==NULL) return -1; + + char *old_locale = setlocale(LC_ALL, NULL); + if (old_locale) { + old_locale = strdup(old_locale); + } + setlocale(LC_ALL, "C"); + + const svm_parameter& param = model->param; + + fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); + fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); + + if(param.kernel_type == POLY) + fprintf(fp,"degree %d\n", param.degree); + + if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID) + fprintf(fp,"gamma %.17g\n", param.gamma); + + if(param.kernel_type == POLY || param.kernel_type == SIGMOID) + fprintf(fp,"coef0 %.17g\n", param.coef0); + + int nr_class = model->nr_class; + int l = model->l; + fprintf(fp, "nr_class %d\n", nr_class); + fprintf(fp, "total_sv %d\n",l); + + { + fprintf(fp, "rho"); + for(int i=0;irho[i]); + fprintf(fp, "\n"); + } + + if(model->label) + { + fprintf(fp, "label"); + for(int i=0;ilabel[i]); + fprintf(fp, "\n"); + } + + if(model->probA) // regression has probA only + { + fprintf(fp, "probA"); + for(int i=0;iprobA[i]); + fprintf(fp, "\n"); + } + if(model->probB) + { + fprintf(fp, "probB"); + for(int i=0;iprobB[i]); + fprintf(fp, "\n"); + } + if(model->prob_density_marks) + { + fprintf(fp, "prob_density_marks"); + int nr_marks=10; + for(int i=0;iprob_density_marks[i]); + fprintf(fp, "\n"); + } + + if(model->nSV) + { + fprintf(fp, "nr_sv"); + for(int i=0;inSV[i]); + fprintf(fp, "\n"); + } + + fprintf(fp, "SV\n"); + const double * const *sv_coef = model->sv_coef; + const svm_node * const *SV = model->SV; + + for(int i=0;ivalue)); + else + while(p->index != -1) + { + fprintf(fp,"%d:%.8g ",p->index,p->value); + p++; + } + fprintf(fp, "\n"); + } + + setlocale(LC_ALL, old_locale); + free(old_locale); + + if (ferror(fp) != 0 || fclose(fp) != 0) return -1; + else return 0; +} + +static char *line = NULL; +static int max_line_len; + +static char* readline(FILE *input) +{ + int len; + + if(fgets(line,max_line_len,input) == NULL) + return NULL; + + while(strrchr(line,'\n') == NULL) + { + max_line_len *= 2; + line = (char *) realloc(line,max_line_len); + len = (int) strlen(line); + if(fgets(line+len,max_line_len-len,input) == NULL) + break; + } + return line; +} + +// +// FSCANF helps to handle fscanf failures. +// Its do-while block avoids the ambiguity when +// if (...) +// FSCANF(); +// is used +// +#define FSCANF(_stream, _format, _var) do{ if (fscanf(_stream, _format, _var) != 1) return false; }while(0) +bool read_model_header(FILE *fp, svm_model* model) +{ + svm_parameter& param = model->param; + // parameters for training only won't be assigned, but arrays are assigned as NULL for safety + param.nr_weight = 0; + param.weight_label = NULL; + param.weight = NULL; + + char cmd[81]; + while(1) + { + FSCANF(fp,"%80s",cmd); + + if(strcmp(cmd,"svm_type")==0) + { + FSCANF(fp,"%80s",cmd); + int i; + for(i=0;svm_type_table[i];i++) + { + if(strcmp(svm_type_table[i],cmd)==0) + { + param.svm_type=i; + break; + } + } + if(svm_type_table[i] == NULL) + { + fprintf(stderr,"unknown svm type.\n"); + return false; + } + } + else if(strcmp(cmd,"kernel_type")==0) + { + FSCANF(fp,"%80s",cmd); + int i; + for(i=0;kernel_type_table[i];i++) + { + if(strcmp(kernel_type_table[i],cmd)==0) + { + param.kernel_type=i; + break; + } + } + if(kernel_type_table[i] == NULL) + { + fprintf(stderr,"unknown kernel function.\n"); + return false; + } + } + else if(strcmp(cmd,"degree")==0) + FSCANF(fp,"%d",¶m.degree); + else if(strcmp(cmd,"gamma")==0) + FSCANF(fp,"%lf",¶m.gamma); + else if(strcmp(cmd,"coef0")==0) + FSCANF(fp,"%lf",¶m.coef0); + else if(strcmp(cmd,"nr_class")==0) + FSCANF(fp,"%d",&model->nr_class); + else if(strcmp(cmd,"total_sv")==0) + FSCANF(fp,"%d",&model->l); + else if(strcmp(cmd,"rho")==0) + { + int n = model->nr_class * (model->nr_class-1)/2; + model->rho = Malloc(double,n); + for(int i=0;irho[i]); + } + else if(strcmp(cmd,"label")==0) + { + int n = model->nr_class; + model->label = Malloc(int,n); + for(int i=0;ilabel[i]); + } + else if(strcmp(cmd,"probA")==0) + { + int n = model->nr_class * (model->nr_class-1)/2; + model->probA = Malloc(double,n); + for(int i=0;iprobA[i]); + } + else if(strcmp(cmd,"probB")==0) + { + int n = model->nr_class * (model->nr_class-1)/2; + model->probB = Malloc(double,n); + for(int i=0;iprobB[i]); + } + else if(strcmp(cmd,"prob_density_marks")==0) + { + int n = 10; // nr_marks + model->prob_density_marks = Malloc(double,n); + for(int i=0;iprob_density_marks[i]); + } + else if(strcmp(cmd,"nr_sv")==0) + { + int n = model->nr_class; + model->nSV = Malloc(int,n); + for(int i=0;inSV[i]); + } + else if(strcmp(cmd,"SV")==0) + { + while(1) + { + int c = getc(fp); + if(c==EOF || c=='\n') break; + } + break; + } + else + { + fprintf(stderr,"unknown text in model file: [%s]\n",cmd); + return false; + } + } + + return true; + +} + +svm_model *svm_load_model(const char *model_file_name) +{ + FILE *fp = fopen(model_file_name,"rb"); + if(fp==NULL) return NULL; + + char *old_locale = setlocale(LC_ALL, NULL); + if (old_locale) { + old_locale = strdup(old_locale); + } + setlocale(LC_ALL, "C"); + + // read parameters + + svm_model *model = Malloc(svm_model,1); + model->rho = NULL; + model->probA = NULL; + model->probB = NULL; + model->prob_density_marks = NULL; + model->sv_indices = NULL; + model->label = NULL; + model->nSV = NULL; + + // read header + if (!read_model_header(fp, model)) + { + fprintf(stderr, "ERROR: fscanf failed to read model\n"); + setlocale(LC_ALL, old_locale); + free(old_locale); + free(model->rho); + free(model->label); + free(model->nSV); + free(model); + return NULL; + } + + // read sv_coef and SV + + int elements = 0; + long pos = ftell(fp); + + max_line_len = 1024; + line = Malloc(char,max_line_len); + char *p,*endptr,*idx,*val; + + while(readline(fp)!=NULL) + { + p = strtok(line,":"); + while(1) + { + p = strtok(NULL,":"); + if(p == NULL) + break; + ++elements; + } + } + elements += model->l; + + fseek(fp,pos,SEEK_SET); + + int m = model->nr_class - 1; + int l = model->l; + model->sv_coef = Malloc(double *,m); + int i; + for(i=0;isv_coef[i] = Malloc(double,l); + model->SV = Malloc(svm_node*,l); + svm_node *x_space = NULL; + if(l>0) x_space = Malloc(svm_node,elements); + + int j=0; + for(i=0;iSV[i] = &x_space[j]; + + p = strtok(line, " \t"); + model->sv_coef[0][i] = strtod(p,&endptr); + for(int k=1;ksv_coef[k][i] = strtod(p,&endptr); + } + + while(1) + { + idx = strtok(NULL, ":"); + val = strtok(NULL, " \t"); + + if(val == NULL) + break; + x_space[j].index = (int) strtol(idx,&endptr,10); + x_space[j].value = strtod(val,&endptr); + + ++j; + } + x_space[j++].index = -1; + } + free(line); + + setlocale(LC_ALL, old_locale); + free(old_locale); + + if (ferror(fp) != 0 || fclose(fp) != 0) + return NULL; + + model->free_sv = 1; // XXX + return model; +} + +void svm_free_model_content(svm_model* model_ptr) +{ + if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL) + free((void *)(model_ptr->SV[0])); + if(model_ptr->sv_coef) + { + for(int i=0;inr_class-1;i++) + free(model_ptr->sv_coef[i]); + } + + free(model_ptr->SV); + model_ptr->SV = NULL; + + free(model_ptr->sv_coef); + model_ptr->sv_coef = NULL; + + free(model_ptr->rho); + model_ptr->rho = NULL; + + free(model_ptr->label); + model_ptr->label = NULL; + + free(model_ptr->probA); + model_ptr->probA = NULL; + + free(model_ptr->probB); + model_ptr->probB = NULL; + + free(model_ptr->prob_density_marks); + model_ptr->prob_density_marks = NULL; + + free(model_ptr->sv_indices); + model_ptr->sv_indices = NULL; + + free(model_ptr->nSV); + model_ptr->nSV = NULL; +} + +void svm_free_and_destroy_model(svm_model** model_ptr_ptr) +{ + if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL) + { + svm_free_model_content(*model_ptr_ptr); + free(*model_ptr_ptr); + *model_ptr_ptr = NULL; + } +} + +void svm_destroy_param(svm_parameter* param) +{ + free(param->weight_label); + free(param->weight); +} + +const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param) +{ + // svm_type + + int svm_type = param->svm_type; + if(svm_type != C_SVC && + svm_type != NU_SVC && + svm_type != ONE_CLASS && + svm_type != EPSILON_SVR && + svm_type != NU_SVR) + return "unknown svm type"; + + // kernel_type, degree + + int kernel_type = param->kernel_type; + if(kernel_type != LINEAR && + kernel_type != POLY && + kernel_type != RBF && + kernel_type != SIGMOID && + kernel_type != PRECOMPUTED) + return "unknown kernel type"; + + if((kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID) && + param->gamma < 0) + return "gamma < 0"; + + if(kernel_type == POLY && param->degree < 0) + return "degree of polynomial kernel < 0"; + + // cache_size,eps,C,nu,p,shrinking + + if(param->cache_size <= 0) + return "cache_size <= 0"; + + if(param->eps <= 0) + return "eps <= 0"; + + if(svm_type == C_SVC || + svm_type == EPSILON_SVR || + svm_type == NU_SVR) + if(param->C <= 0) + return "C <= 0"; + + if(svm_type == NU_SVC || + svm_type == ONE_CLASS || + svm_type == NU_SVR) + if(param->nu <= 0 || param->nu > 1) + return "nu <= 0 or nu > 1"; + + if(svm_type == EPSILON_SVR) + if(param->p < 0) + return "p < 0"; + + if(param->shrinking != 0 && + param->shrinking != 1) + return "shrinking != 0 and shrinking != 1"; + + if(param->probability != 0 && + param->probability != 1) + return "probability != 0 and probability != 1"; + + + // check whether nu-svc is feasible + + if(svm_type == NU_SVC) + { + int l = prob->l; + int max_nr_class = 16; + int nr_class = 0; + int *label = Malloc(int,max_nr_class); + int *count = Malloc(int,max_nr_class); + + int i; + for(i=0;iy[i]; + int j; + for(j=0;jnu*(n1+n2)/2 > min(n1,n2)) + { + free(label); + free(count); + return "specified nu is infeasible"; + } + } + } + free(label); + free(count); + } + + return NULL; +} + +int svm_check_probability_model(const svm_model *model) +{ + return + ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && + model->probA!=NULL && model->probB!=NULL) || + (model->param.svm_type == ONE_CLASS && model->prob_density_marks!=NULL) || + ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && + model->probA!=NULL); +} + +void svm_set_print_string_function(void (*print_func)(const char *)) +{ + if(print_func == NULL) + svm_print_string = &print_string_stdout; + else + svm_print_string = print_func; +} diff --git a/primedev/server/svm.h b/primedev/server/svm.h new file mode 100644 index 000000000..3fb55270b --- /dev/null +++ b/primedev/server/svm.h @@ -0,0 +1,105 @@ +#ifndef _LIBSVM_H +#define _LIBSVM_H + +#define LIBSVM_VERSION 333 + +#ifdef __cplusplus +extern "C" { +#endif + +extern int libsvm_version; + +struct svm_node +{ + int index; + double value; +}; + +struct svm_problem +{ + int l; + double *y; + struct svm_node **x; +}; + +enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */ +enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */ + +struct svm_parameter +{ + int svm_type; + int kernel_type; + int degree; /* for poly */ + double gamma; /* for poly/rbf/sigmoid */ + double coef0; /* for poly/sigmoid */ + + /* these are for training only */ + double cache_size; /* in MB */ + double eps; /* stopping criteria */ + double C; /* for C_SVC, EPSILON_SVR and NU_SVR */ + int nr_weight; /* for C_SVC */ + int *weight_label; /* for C_SVC */ + double* weight; /* for C_SVC */ + double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */ + double p; /* for EPSILON_SVR */ + int shrinking; /* use the shrinking heuristics */ + int probability; /* do probability estimates */ +}; + +// +// svm_model +// +struct svm_model +{ + struct svm_parameter param; /* parameter */ + int nr_class; /* number of classes, = 2 in regression/one class svm */ + int l; /* total #SV */ + struct svm_node **SV; /* SVs (SV[l]) */ + double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ + double *rho; /* constants in decision functions (rho[k*(k-1)/2]) */ + double *probA; /* pariwise probability information */ + double *probB; + double *prob_density_marks; /* probability information for ONE_CLASS */ + int *sv_indices; /* sv_indices[0,...,nSV-1] are values in [1,...,num_traning_data] to indicate SVs in the training set */ + + /* for classification only */ + + int *label; /* label of each class (label[k]) */ + int *nSV; /* number of SVs for each class (nSV[k]) */ + /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */ + /* XXX */ + int free_sv; /* 1 if svm_model is created by svm_load_model*/ + /* 0 if svm_model is created by svm_train */ +}; + +struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param); +void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target); + +int svm_save_model(const char *model_file_name, const struct svm_model *model); +struct svm_model *svm_load_model(const char *model_file_name); + +int svm_get_svm_type(const struct svm_model *model); +int svm_get_nr_class(const struct svm_model *model); +void svm_get_labels(const struct svm_model *model, int *label); +void svm_get_sv_indices(const struct svm_model *model, int *sv_indices); +int svm_get_nr_sv(const struct svm_model *model); +double svm_get_svr_probability(const struct svm_model *model); + +double svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values); +double svm_predict(const struct svm_model *model, const struct svm_node *x); +double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates); + +void svm_free_model_content(struct svm_model *model_ptr); +void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr); +void svm_destroy_param(struct svm_parameter *param); + +const char *svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param); +int svm_check_probability_model(const struct svm_model *model); + +void svm_set_print_string_function(void (*print_func)(const char *)); + +#ifdef __cplusplus +} +#endif + +#endif /* _LIBSVM_H */ diff --git a/primedev/thirdparty/openssl-cmake b/primedev/thirdparty/openssl-cmake deleted file mode 160000 index 186966f4a..000000000 --- a/primedev/thirdparty/openssl-cmake +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 186966f4ae15925b0dd7f370cdbb654b00e134e2 diff --git a/primedev/util/base64.cpp b/primedev/util/base64.cpp new file mode 100644 index 000000000..5a232468c --- /dev/null +++ b/primedev/util/base64.cpp @@ -0,0 +1,104 @@ +#include "base64.h" +#include "pch.h" +#include +#include + +static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(BYTE c) +{ + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::string base64_encode(BYTE const* buf, unsigned int bufLen) +{ + std::string ret; + int i = 0; + int j = 0; + BYTE char_array_3[3]; + BYTE char_array_4[4]; + + while (bufLen--) + { + char_array_3[i++] = *(buf++); + if (i == 3) + { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for (i = 0; (i < 4); i++) + ret += base64_chars[char_array_4[i]]; + i = 0; + } + } + + if (i) + { + for (j = i; j < 3; j++) + char_array_3[j] = '\0'; + + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for (j = 0; (j < i + 1); j++) + ret += base64_chars[char_array_4[j]]; + + while ((i++ < 3)) + ret += '='; + } + + return ret; +} + +std::vector base64_decode(std::string const& encoded_string) +{ + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + BYTE char_array_4[4], char_array_3[3]; + std::vector ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) + { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) + { + for (i = 0; i < 4; i++) + char_array_4[i] = base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret.push_back(char_array_3[i]); + i = 0; + } + } + + if (i) + { + for (j = i; j < 4; j++) + char_array_4[j] = 0; + + for (j = 0; j < 4; j++) + char_array_4[j] = base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) + ret.push_back(char_array_3[j]); + } + + return ret; +} diff --git a/primedev/util/base64.h b/primedev/util/base64.h new file mode 100644 index 000000000..2da9041a7 --- /dev/null +++ b/primedev/util/base64.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include +typedef unsigned char BYTE; + +std::string base64_encode(BYTE const* buf, unsigned int bufLen); +std::vector base64_decode(std::string const&); diff --git a/primedev/util/dohworker.cpp b/primedev/util/dohworker.cpp new file mode 100644 index 000000000..ad3b5b3b5 --- /dev/null +++ b/primedev/util/dohworker.cpp @@ -0,0 +1,127 @@ + +#include "dohworker.h" +#include "rapidjson/document.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "rapidjson/error/en.h" +#include +#include +#include +#include "masterserver/masterserver.h" +AUTOHOOK_INIT() + +DohWorker* g_DohWorker = new(DohWorker); + +size_t DOHCurlWriteToStringBufferCallback(char* contents, size_t size, size_t nmemb, void* userp) +{ + ((std::string*)userp)->append((char*)contents, size * nmemb); + return size * nmemb; +} + + +std::string DohWorker::GetDOHResolve(std::string domainname) +{ + bool stripHeaders = false; + if ((domainname.find("https") != std::string::npos)) + { + //spdlog::info("[DOH] Found https like headers in {}", domainname); + domainname.erase(0, 8); + stripHeaders = true; + } + if ((domainname.find("http") != std::string::npos) && !stripHeaders) + { + //spdlog::info("[DOH] Found http like headers in {}", domainname); + domainname.erase(0, 7); + stripHeaders = true; + } + if (localresolvcache.find(domainname) != localresolvcache.end()) + { + //spdlog::info("[DOH] Found cache for {}", domainname); + return localresolvcache[domainname]; + } + else + { + return ResolveDomain(domainname); + } +} + +std::string DohWorker::ResolveDomain(std::string domainname) +{ + while(is_resolving) + Sleep(100); + bool stripHeaders = false; + is_resolving = true; + if ((domainname.find("https") != std::string::npos)) + { + //spdlog::info("[DOH] Found https like headers in {}", domainname); + domainname.erase(0, 8); + stripHeaders = true; + } + if ((domainname.find("http") != std::string::npos) && !stripHeaders) + { + //spdlog::info("[DOH] Found http like headers in {}", domainname); + domainname.erase(0, 7); + stripHeaders = true; + } + //spdlog::info("[DOH] Resolving {}", domainname); + CURL* curl = curl_easy_init(); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); + std::string readBuffer; + char* domainnameescaped = curl_easy_escape(curl, domainname.c_str(), domainname.length()); + curl_easy_setopt( + curl, + CURLOPT_URL, + fmt::format( + "https://223.6.6.6/resolve?name={}&type=1&short=1", + domainnameescaped) + .c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "GET"); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, DOHCurlWriteToStringBufferCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + + CURLcode result = curl_easy_perform(curl); + + if (result == CURLcode::CURLE_OK) + { + //spdlog::info(readBuffer); + readBuffer = readBuffer.substr(2, readBuffer.length() - 4); + + + if (!readBuffer.empty()) + { + g_DohWorker->localresolvcache.insert_or_assign(domainname, readBuffer); + spdlog::info("[DOH] Successfully resolved {} , result: {}", domainname, readBuffer); + curl_easy_cleanup(curl); + m_bDohAvailable = true; + is_resolving = false; + return readBuffer; + } + else + { + spdlog::info("[DOH] Failed resolving {}, Got data error.", domainname); + curl_easy_cleanup(curl); + m_bDohAvailable = false; + is_resolving = false; + return ""; + } + + //spdlog::error("Failed reading player clantag"); + } + else + { + // DOH failed, we need to fallback to use normal DNS + if(m_bDohAvailable) + spdlog::error("[DOH] CURL failed : error {}", curl_easy_strerror(result)); + m_bDohAvailable = false; + } + + is_resolving = false; + curl_easy_cleanup(curl); + return ""; +} + +ON_DLL_LOAD("client.dll", DohWorkerInitialize, (CModule module)) +{ + g_DohWorker->ResolveDomain("nscn.wolf109909.top"); +} diff --git a/primedev/util/dohworker.h b/primedev/util/dohworker.h new file mode 100644 index 000000000..de20efada --- /dev/null +++ b/primedev/util/dohworker.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include +#include +#include + +class DohWorker +{ +public: + bool m_bDohAvailable = false; + bool is_resolving = false; + std::map localresolvcache; + void ExecuteDefaults(); + std::string ResolveDomain(std::string domainname); + std::string GetDOHResolve(std::string domainname); +}; + +extern DohWorker* g_DohWorker; diff --git a/primedev/vscript/languages/squirrel_re/squirrel/squserdata.h b/primedev/vscript/languages/squirrel_re/squirrel/squserdata.h index 98ede8878..02d820d31 100644 --- a/primedev/vscript/languages/squirrel_re/squirrel/squserdata.h +++ b/primedev/vscript/languages/squirrel_re/squirrel/squserdata.h @@ -8,7 +8,7 @@ struct SQUserData : public SQDelegable { int size; char padding1[4]; - void* (*releasehook)(void* val, int size); + void (*releasehook)(void* val, int size); long long typeId; char data[1]; };