diff --git a/src/emulator/serialization.hpp b/src/emulator/serialization.hpp index 14e2b6af..4693ce3e 100644 --- a/src/emulator/serialization.hpp +++ b/src/emulator/serialization.hpp @@ -22,6 +22,16 @@ namespace utils { a.deserialize(deserializer) } -> std::same_as; }; + template + struct is_optional : std::false_type + { + }; + + template + struct is_optional> : std::true_type + { + }; + namespace detail { template @@ -349,6 +359,12 @@ namespace utils const uint64_t old_size = this->buffer_.size(); #endif + if (this->break_offset_ && this->buffer_.size() <= *this->break_offset_ && + this->buffer_.size() + length > *this->break_offset_) + { + throw std::runtime_error("Break offset reached!"); + } + const auto* byte_buffer = static_cast(buffer); this->buffer_.insert(this->buffer_.end(), byte_buffer, byte_buffer + length); @@ -365,6 +381,7 @@ namespace utils } template + requires(!is_optional::value) void write(const T& object) { constexpr auto is_trivially_copyable = std::is_trivially_copyable_v; @@ -475,8 +492,47 @@ namespace utils return std::move(this->buffer_); } + void set_break_offset(const size_t break_offset) + { + this->break_offset_ = break_offset; + } + + std::optional get_diff(const buffer_serializer& other) const + { + auto& b1 = this->get_buffer(); + auto& b2 = other.get_buffer(); + + const auto s1 = b1.size(); + const auto s2 = b2.size(); + + for (size_t i = 0; i < s1 && i < s2; ++i) + { + if (b1.at(i) != b2.at(i)) + { + return i; + } + } + + if (s1 != s2) + { + return std::min(s1, s2); + } + + return std::nullopt; + } + + void print_diff(const buffer_serializer& other) const + { + const auto diff = this->get_diff(other); + if (diff) + { + printf("Diff at %zd\n", *diff); + } + } + private: std::vector buffer_{}; + std::optional break_offset_{}; }; template <> diff --git a/src/samples/test-sample/CMakeLists.txt b/src/samples/test-sample/CMakeLists.txt index 6a131eff..bd5f2878 100644 --- a/src/samples/test-sample/CMakeLists.txt +++ b/src/samples/test-sample/CMakeLists.txt @@ -9,3 +9,7 @@ list(SORT SRC_FILES) add_executable(test-sample ${SRC_FILES}) momo_assign_source_group(${SRC_FILES}) + +target_link_libraries(test-sample PRIVATE + emulator-common +) diff --git a/src/samples/test-sample/test.cpp b/src/samples/test-sample/test.cpp index 83282f81..d358bcf5 100644 --- a/src/samples/test-sample/test.cpp +++ b/src/samples/test-sample/test.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include using namespace std::literals; @@ -195,7 +195,7 @@ std::optional read_registry_string(const HKEY root, const char* pat return ""; } - return {std::string(data, min(length - 1, sizeof(data)))}; + return {std::string(data, std::min(static_cast(length - 1), sizeof(data)))}; } bool test_registry() @@ -231,6 +231,36 @@ bool test_exceptions() } } +bool test_socket() +{ + network::udp_socket receiver{AF_INET}; + const network::udp_socket sender{AF_INET}; + const network::address destination{"127.0.0.1:28970", AF_INET}; + constexpr std::string_view send_data = "Hello World"; + + if (!receiver.bind(destination)) + { + puts("Failed to bind socket!"); + return false; + } + + if (!sender.send(destination, send_data)) + { + puts("Failed to send data!"); + return false; + } + + const auto response = receiver.receive(); + + if (!response) + { + puts("Failed to recieve data!"); + return false; + } + + return send_data == response->second; +} + void throw_access_violation() { if (do_the_task) @@ -256,7 +286,7 @@ bool test_ud2_exception(void* address) { __try { - static_cast(address)(); + reinterpret_cast(address)(); return false; } __except (EXCEPTION_EXECUTE_HANDLER) @@ -301,12 +331,18 @@ void print_time() puts(res ? "Success" : "Fail"); \ } -int main(int argc, const char* argv[]) +int main(const int argc, const char* argv[]) { - if (argc == 2 && argv[1] == "-time"sv) + bool reproducible = false; + if (argc == 2) { - print_time(); - return 0; + if (argv[1] == "-time"sv) + { + print_time(); + return 0; + } + + reproducible = argv[1] == "-reproducible"sv; } bool valid = true; @@ -320,5 +356,10 @@ int main(int argc, const char* argv[]) RUN_TEST(test_native_exceptions, "Native Exceptions") RUN_TEST(test_tls, "TLS") + if (!reproducible) + { + RUN_TEST(test_socket, "Socket") + } + return valid ? 0 : 1; } diff --git a/src/tools/create-root.bat b/src/tools/create-root.bat index 5a89a64d..d457c7a8 100644 --- a/src/tools/create-root.bat +++ b/src/tools/create-root.bat @@ -110,6 +110,8 @@ CALL :collect_dll mscms.dll CALL :collect_dll ktmw32.dll CALL :collect_dll shcore.dll CALL :collect_dll diagnosticdatasettings.dll +CALL :collect_dll mswsock.dll +CALL :collect_dll umpdc.dll CALL :collect_dll locale.nls diff --git a/src/windows-emulator-test/emulation_test_utils.hpp b/src/windows-emulator-test/emulation_test_utils.hpp index fc483139..72028a01 100644 --- a/src/windows-emulator-test/emulation_test_utils.hpp +++ b/src/windows-emulator-test/emulation_test_utils.hpp @@ -38,14 +38,20 @@ namespace test return env; } - inline windows_emulator create_sample_emulator(emulator_settings settings, emulator_callbacks callbacks = {}) + inline windows_emulator create_sample_emulator(emulator_settings settings, const bool reproducible = false, + emulator_callbacks callbacks = {}) { const auto is_verbose = enable_verbose_logging(); if (is_verbose) { settings.disable_logging = false; - settings.verbose_calls = true; + // settings.verbose_calls = true; + } + + if (reproducible) + { + settings.arguments = {u"-reproducible"}; } settings.application = "c:/test-sample.exe"; @@ -53,13 +59,73 @@ namespace test return windows_emulator{std::move(settings), std::move(callbacks)}; } - inline windows_emulator create_sample_emulator() + inline windows_emulator create_sample_emulator(const bool reproducible = false) { emulator_settings settings{ .disable_logging = true, .use_relative_time = true, }; - return create_sample_emulator(std::move(settings)); + return create_sample_emulator(std::move(settings), reproducible); + } + + inline void bisect_emulation(windows_emulator& emu) + { + utils::buffer_serializer start_state{}; + emu.serialize(start_state); + + emu.start(); + const auto limit = emu.process().executed_instructions; + + const auto reset_emulator = [&] { + utils::buffer_deserializer deserializer{start_state.get_buffer()}; + emu.deserialize(deserializer); + }; + + const auto get_state_for_count = [&](const size_t count) { + reset_emulator(); + emu.start({}, count); + + utils::buffer_serializer state{}; + emu.serialize(state); + return state; + }; + + const auto has_diff_after_count = [&](const size_t count) { + const auto s1 = get_state_for_count(count); + const auto s2 = get_state_for_count(count); + + return s1.get_diff(s2).has_value(); + }; + + if (!has_diff_after_count(limit)) + { + puts("Emulation has no diff"); + } + + auto upper_bound = limit; + decltype(upper_bound) lower_bound = 0; + + printf("Bounds: %" PRIx64 " - %" PRIx64 "\n", lower_bound, upper_bound); + + while (lower_bound + 1 < upper_bound) + { + const auto diff = (upper_bound - lower_bound); + const auto pivot = lower_bound + (diff / 2); + + const auto has_diff = has_diff_after_count(pivot); + + auto* bound = has_diff ? &upper_bound : &lower_bound; + *bound = pivot; + + printf("Bounds: %" PRIx64 " - %" PRIx64 "\n", lower_bound, upper_bound); + } + + (void)get_state_for_count(lower_bound); + + const auto rip = emu.emu().read_instruction_pointer(); + + printf("Diff detected after 0x%" PRIx64 " instructions at 0x%" PRIx64 " (%s)\n", lower_bound, rip, + emu.process().mod_manager.find_name(rip)); } } diff --git a/src/windows-emulator-test/serialization_test.cpp b/src/windows-emulator-test/serialization_test.cpp index 95461785..990e48f8 100644 --- a/src/windows-emulator-test/serialization_test.cpp +++ b/src/windows-emulator-test/serialization_test.cpp @@ -4,7 +4,7 @@ namespace test { TEST(SerializationTest, ResettingEmulatorWorks) { - auto emu = create_sample_emulator(); + auto emu = create_sample_emulator(true); utils::buffer_serializer start_state{}; emu.serialize(start_state); @@ -31,7 +31,7 @@ namespace test TEST(SerializationTest, SerializedDataIsReproducible) { - auto emu1 = create_sample_emulator(); + auto emu1 = create_sample_emulator(true); emu1.start(); ASSERT_TERMINATED_SUCCESSFULLY(emu1); @@ -55,7 +55,7 @@ namespace test TEST(SerializationTest, EmulationIsReproducible) { - auto emu1 = create_sample_emulator(); + auto emu1 = create_sample_emulator(true); emu1.start(); ASSERT_TERMINATED_SUCCESSFULLY(emu1); @@ -63,7 +63,7 @@ namespace test utils::buffer_serializer serializer1{}; emu1.serialize(serializer1); - auto emu2 = create_sample_emulator(); + auto emu2 = create_sample_emulator(true); emu2.start(); ASSERT_TERMINATED_SUCCESSFULLY(emu2); @@ -76,7 +76,7 @@ namespace test TEST(SerializationTest, DeserializedEmulatorBehavesLikeSource) { - auto emu = create_sample_emulator(); + auto emu = create_sample_emulator(true); emu.start({}, 100); utils::buffer_serializer serializer{}; diff --git a/src/windows-emulator-test/time_test.cpp b/src/windows-emulator-test/time_test.cpp index 38555971..bb4b1e01 100644 --- a/src/windows-emulator-test/time_test.cpp +++ b/src/windows-emulator-test/time_test.cpp @@ -16,7 +16,7 @@ namespace test .use_relative_time = false, }; - auto emu = create_sample_emulator(settings, callbacks); + auto emu = create_sample_emulator(settings, false, callbacks); emu.start(); constexpr auto prefix = "Time: "sv; diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 1a364eb4..0068561b 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -22,6 +22,189 @@ namespace // ... }; + struct win_sockaddr + { + int16_t sa_family; + uint8_t sa_data[14]; + }; + + struct win_sockaddr_in + { + int16_t sin_family; + uint16_t sin_port; + in_addr sin_addr; + uint8_t sin_zero[8]; + }; + + struct win_sockaddr_in6 + { + int16_t sin6_family; + uint16_t sin6_port; + uint32_t sin6_flowinfo; + in6_addr sin6_addr; + uint32_t sin6_scope_id; + }; + + static_assert(sizeof(win_sockaddr) == 16); + static_assert(sizeof(win_sockaddr_in) == 16); + static_assert(sizeof(win_sockaddr_in6) == 28); + + static_assert(sizeof(win_sockaddr_in::sin_addr) == 4); + static_assert(sizeof(win_sockaddr_in6::sin6_addr) == 16); + static_assert(sizeof(win_sockaddr_in6::sin6_flowinfo) == sizeof(sockaddr_in6::sin6_flowinfo)); + static_assert(sizeof(win_sockaddr_in6::sin6_scope_id) == sizeof(sockaddr_in6::sin6_scope_id)); + + const std::map address_family_map{ + {0, AF_UNSPEC}, // + {2, AF_INET}, // + {23, AF_INET6}, // + }; + + const std::map socket_type_map{ + {0, 0}, // + {1, SOCK_STREAM}, // + {2, SOCK_DGRAM}, // + {3, SOCK_RAW}, // + {4, SOCK_RDM}, // + }; + + const std::map socket_protocol_map{ + {0, 0}, // + {6, IPPROTO_TCP}, // + {17, IPPROTO_UDP}, // + {255, IPPROTO_RAW}, // + }; + + int16_t translate_host_to_win_address_family(const int host_af) + { + for (auto& entry : address_family_map) + { + if (entry.second == host_af) + { + return static_cast(entry.first); + } + } + + throw std::runtime_error("Unknown host address family: " + std::to_string(host_af)); + } + + int translate_win_to_host_address_family(const int win_af) + { + const auto entry = address_family_map.find(win_af); + if (entry != address_family_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown address family: " + std::to_string(win_af)); + } + + int translate_win_to_host_type(const int win_type) + { + const auto entry = socket_type_map.find(win_type); + if (entry != socket_type_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown socket type: " + std::to_string(win_type)); + } + + int translate_win_to_host_protocol(const int win_protocol) + { + const auto entry = socket_protocol_map.find(win_protocol); + if (entry != socket_protocol_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown socket protocol: " + std::to_string(win_protocol)); + } + + std::vector convert_to_win_address(const network::address& a) + { + if (a.is_ipv4()) + { + win_sockaddr_in win_addr{}; + win_addr.sin_family = translate_host_to_win_address_family(a.get_family()); + win_addr.sin_port = htons(a.get_port()); + memcpy(&win_addr.sin_addr, &a.get_in_addr().sin_addr, sizeof(win_addr.sin_addr)); + + const auto ptr = reinterpret_cast(&win_addr); + return {ptr, ptr + sizeof(win_addr)}; + } + + if (a.is_ipv6()) + { + win_sockaddr_in6 win_addr{}; + win_addr.sin6_family = translate_host_to_win_address_family(a.get_family()); + win_addr.sin6_port = htons(a.get_port()); + + auto& addr = a.get_in6_addr(); + memcpy(&win_addr.sin6_addr, &addr.sin6_addr, sizeof(win_addr.sin6_addr)); + win_addr.sin6_flowinfo = addr.sin6_flowinfo; + win_addr.sin6_scope_id = addr.sin6_scope_id; + + const auto ptr = reinterpret_cast(&win_addr); + return {ptr, ptr + sizeof(win_addr)}; + } + + throw std::runtime_error("Unsupported host address family for conversion: " + std::to_string(a.get_family())); + } + + network::address convert_to_host_address(const std::span data) + { + if (data.size() < sizeof(win_sockaddr)) + { + throw std::runtime_error("Bad address size"); + } + + win_sockaddr win_addr{}; + memcpy(&win_addr, data.data(), sizeof(win_addr)); + + const auto family = translate_win_to_host_address_family(win_addr.sa_family); + + network::address a{}; + + if (family == AF_INET) + { + if (data.size() < sizeof(win_sockaddr_in)) + { + throw std::runtime_error("Bad IPv4 address size"); + } + + win_sockaddr_in win_addr4{}; + memcpy(&win_addr4, data.data(), sizeof(win_addr4)); + + a.set_ipv4(win_addr4.sin_addr); + a.set_port(ntohs(win_addr4.sin_port)); + + return a; + } + + if (family == AF_INET6) + { + if (data.size() < sizeof(win_sockaddr_in6)) + { + throw std::runtime_error("Bad IPv6 address size"); + } + + win_sockaddr_in6 win_addr6{}; + memcpy(&win_addr6, data.data(), sizeof(win_addr6)); + + a.set_ipv6(win_addr6.sin6_addr); + a.set_port(ntohs(win_addr6.sin6_port)); + + auto& addr = a.get_in6_addr(); + addr.sin6_flowinfo = win_addr6.sin6_flowinfo; + addr.sin6_scope_id = win_addr6.sin6_scope_id; + + return a; + } + + throw std::runtime_error("Unsupported win address family for conversion: " + std::to_string(family)); + } + afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data) { if (!data.buffer || data.length < sizeof(afd_creation_data)) @@ -216,8 +399,11 @@ namespace const auto& data = *this->creation_data; - // TODO: values map to windows values; might not be the case for other platforms - const auto sock = socket(data.address_family, data.type, data.protocol); + const auto af = translate_win_to_host_address_family(data.address_family); + const auto type = translate_win_to_host_type(data.type); + const auto protocol = translate_win_to_host_protocol(data.protocol); + + const auto sock = socket(af, type, protocol); if (sock == INVALID_SOCKET) { throw std::runtime_error("Failed to create socket!"); @@ -290,20 +476,20 @@ namespace void deserialize(utils::buffer_deserializer& buffer) override { - buffer.read(this->creation_data); + buffer.read_optional(this->creation_data); this->setup(); - buffer.read(this->require_poll_); - buffer.read(this->delayed_ioctl_); - buffer.read(this->timeout_); + buffer.read_optional(this->require_poll_); + buffer.read_optional(this->delayed_ioctl_); + buffer.read_optional(this->timeout_); } void serialize(utils::buffer_serializer& buffer) const override { - buffer.write(this->creation_data); - buffer.write(this->require_poll_); - buffer.write(this->delayed_ioctl_); - buffer.write(this->timeout_); + buffer.write_optional(this->creation_data); + buffer.write_optional(this->require_poll_); + buffer.write_optional(this->delayed_ioctl_); + buffer.write_optional(this->timeout_); } NTSTATUS io_control(windows_emulator& win_emu, const io_device_context& c) override @@ -339,7 +525,7 @@ namespace NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const { - const auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length); + auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length); constexpr auto address_offset = 4; @@ -348,10 +534,7 @@ namespace return STATUS_BUFFER_TOO_SMALL; } - const auto* address = reinterpret_cast(data.data() + address_offset); - const auto address_size = static_cast(data.size() - address_offset); - - const network::address addr(address, address_size); + const auto addr = convert_to_host_address(std::span(data).subspan(address_offset)); if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) { @@ -431,28 +614,19 @@ namespace const auto receive_info = emu.read_memory>>(c.input_buffer); const auto buffer = emu.read_memory>>(receive_info.BufferArray); - std::vector address{}; - - unsigned long address_length = 0x1000; - if (receive_info.AddressLength) - { - address_length = emu.read_memory(receive_info.AddressLength); - } - - address.resize(std::clamp(address_length, 1UL, 0x1000UL)); - if (!buffer.len || buffer.len > 0x10000 || !buffer.buf) { return STATUS_INVALID_PARAMETER; } - auto fromlength = static_cast(address.size()); + network::address from{}; + auto from_length = from.get_max_size(); std::vector data{}; data.resize(buffer.len); const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, - reinterpret_cast(address.data()), &fromlength); + &from.get_addr(), &from_length); if (recevied_data < 0) { @@ -466,13 +640,20 @@ namespace return STATUS_UNSUCCESSFUL; } + assert(from.get_size() == from_length); + const auto data_size = std::min(data.size(), static_cast(recevied_data)); emu.write_memory(buffer.buf, data.data(), data_size); - if (receive_info.Address && address_length) + const auto win_from = convert_to_win_address(from); + + if (receive_info.Address && receive_info.AddressLength) { - const auto address_size = std::min(address.size(), static_cast(address_length)); - emu.write_memory(receive_info.Address, address.data(), address_size); + const emulator_object address_length{emu, receive_info.AddressLength}; + const auto address_size = std::min(win_from.size(), static_cast(address_length.read())); + + emu.write_memory(receive_info.Address, win_from.data(), address_size); + address_length.write(static_cast(address_size)); } if (c.io_status_block) @@ -497,17 +678,15 @@ namespace const auto send_info = emu.read_memory>>(c.input_buffer); const auto buffer = emu.read_memory>>(send_info.BufferArray); - const auto address = emu.read_memory(send_info.TdiConnInfo.RemoteAddress, - static_cast(send_info.TdiConnInfo.RemoteAddressLength)); - - const network::address target(reinterpret_cast(address.data()), - static_cast(address.size())); + auto address_buffer = emu.read_memory(send_info.TdiConnInfo.RemoteAddress, + static_cast(send_info.TdiConnInfo.RemoteAddressLength)); + const auto target = convert_to_host_address(address_buffer); const auto data = emu.read_memory(buffer.buf, buffer.len); const auto sent_data = sendto(*this->s_, reinterpret_cast(data.data()), static_cast(data.size()), - 0 /* ? */, &target.get_addr(), target.get_size()); + 0 /* TODO */, &target.get_addr(), target.get_size()); if (sent_data < 0) {