Skip to content

Commit

Permalink
move packet to protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
lionkor committed Jan 17, 2024
1 parent 9e99177 commit 9502048
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 102 deletions.
2 changes: 1 addition & 1 deletion deps/BeamMP-Protocol
30 changes: 8 additions & 22 deletions include/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "State.h"
#include "Sync.h"
#include "Transport.h"
#include "Packet.h"
#include <boost/asio.hpp>
#include <boost/thread/scoped_thread.hpp>
#include <cstdint>
Expand All @@ -16,21 +17,6 @@ using VehicleID = uint16_t;

using namespace boost::asio;

struct Packet {
bmp::Purpose purpose;
bmp::Flags flags;

/// Returns data with consideration to flags.
std::vector<uint8_t> get_readable_data() const;

/// Sets flags (e.g. compression flag) if the data is above some threshold,
/// and compresses the data.
/// Returns the header needed to send this packet.
[[nodiscard]] bmp::Header finalize();

/// Raw (potentially compressed) data -- do not read directly to deserialize from.
std::vector<uint8_t> raw_data;
};

struct Client {
using Ptr = std::shared_ptr<Client>;
Expand All @@ -43,9 +29,9 @@ struct Client {
Sync<std::unordered_map<std::string /* identifier */, std::string /* value */>> identifiers;

/// Reads a single packet from the TCP stream. Blocks all other reads (not writes).
Packet tcp_read();
bmp::Packet tcp_read();
/// Writes the packet to the TCP stream. Blocks all other writes.
void tcp_write(Packet& packet);
void tcp_write(bmp::Packet& packet);
/// Writes the specified to the TCP stream without a header or any metadata - use in
/// conjunction with something else. Blocks other writes.
void tcp_write_file_raw(const std::filesystem::path& path);
Expand Down Expand Up @@ -88,13 +74,13 @@ class Network {
~Network();

/// Reads a packet from the given UDP socket, returning the client's endpoint as an out-argument.
Packet udp_read(ip::udp::endpoint& out_ep);
bmp::Packet udp_read(ip::udp::endpoint& out_ep);
/// Sends a packet to the specified UDP endpoint via the UDP socket.
void udp_write(Packet& packet, const ip::udp::endpoint& to_ep);
void udp_write(bmp::Packet& packet, const ip::udp::endpoint& to_ep);

void disconnect(ClientID id, const std::string& msg);

void handle_packet(ClientID i, const Packet& packet);
void handle_packet(ClientID i, const bmp::Packet& packet);

private:
void udp_read_main();
Expand All @@ -121,9 +107,9 @@ class Network {
Sync<bool> m_shutdown { false };
ip::udp::socket m_udp_socket { m_io };

void handle_identification(ClientID id, const Packet& packet, std::shared_ptr<Client>& client);
void handle_identification(ClientID id, const bmp::Packet& packet, std::shared_ptr<Client>& client);

void handle_authentication(ClientID id, const Packet& packet, std::shared_ptr<Client>& client);
void handle_authentication(ClientID id, const bmp::Packet& packet, std::shared_ptr<Client>& client);

/// On failure, throws an exception with the error for the client.
static void authenticate_user(const std::string& public_key, std::shared_ptr<Client>& client);
Expand Down
99 changes: 20 additions & 79 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "Environment.h"
#include "Http.h"
#include "LuaAPI.h"
#include "Packet.h"
#include "ProtocolVersion.h"
#include "ServerInfo.h"
#include "TLuaEngine.h"
Expand All @@ -24,53 +25,9 @@

#include <doctest/doctest.h>

std::vector<uint8_t> Packet::get_readable_data() const {
if ((flags & bmp::Flags::ZstdCompressed) != 0) {
return bmp::zstd_decompress(raw_data);
} else {
return raw_data;
}
}

TEST_CASE("Packet finalize") {
Packet packet {
.purpose = bmp::Purpose::Invalid,
};
SUBCASE("No compression, under threshold") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD - 1, 5);
(void)packet.finalize();
// not compressed, still the same
CHECK(std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
;
// no compression flag
CHECK_EQ(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
SUBCASE("Compression via threshold") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD + 1, 5);
(void)packet.finalize();
// compressed, not the exact same
CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
// decompressable
CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data));
// compression flag set
CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
SUBCASE("Compression flag") {
packet.raw_data = std::vector<uint8_t>(bmp::COMPRESSION_THRESHOLD - 1, 5);
packet.flags = bmp::Flags(packet.flags | bmp::Flags::ZstdCompressed);
(void)packet.finalize();
// compressed, not the exact same
CHECK(!std::all_of(packet.raw_data.begin(), packet.raw_data.end(), [](uint8_t value) { return value == 5; }));
// decompressable
CHECK_NOTHROW(bmp::zstd_decompress(packet.raw_data));
// compression flag set
CHECK_NE(packet.flags & bmp::Flags::ZstdCompressed, 0);
}
}

Packet Client::tcp_read() {
bmp::Packet Client::tcp_read() {
std::unique_lock lock(m_tcp_read_mtx);
Packet packet {};
bmp::Packet packet {};
std::vector<uint8_t> header_buffer(bmp::Header::SERIALIZED_SIZE);
read(m_tcp_socket, buffer(header_buffer));
bmp::Header hdr {};
Expand All @@ -83,7 +40,7 @@ Packet Client::tcp_read() {
return packet;
}

void Client::tcp_write(Packet& packet) {
void Client::tcp_write(bmp::Packet& packet) {
beammp_tracef("Sending 0x{:x} to {}", int(packet.purpose), id);
// acquire a lock to avoid writing a header, then being interrupted by another write
std::unique_lock lock(m_tcp_write_mtx);
Expand Down Expand Up @@ -156,27 +113,11 @@ void Client::tcp_main() {
beammp_debugf("TCP thread stopped for client {}", id);
}

bmp::Header Packet::finalize() {
// the user can force zstd compression on before setting data to force compression,
// otherwise the threshold is used.
if ((flags & bmp::Flags::ZstdCompressed) != 0
|| raw_data.size() > bmp::COMPRESSION_THRESHOLD) {
flags = bmp::Flags(flags | bmp::Flags::ZstdCompressed);
raw_data = bmp::zstd_compress(raw_data);
}
return {
.purpose = purpose,
.flags = flags,
.rsv = 0,
.size = static_cast<uint32_t>(raw_data.size()),
};
}

Packet Network::udp_read(ip::udp::endpoint& out_ep) {
bmp::Packet Network::udp_read(ip::udp::endpoint& out_ep) {
// maximum we can ever expect from udp
static thread_local std::vector<uint8_t> s_buffer(std::numeric_limits<uint16_t>::max());
m_udp_socket.receive_from(buffer(s_buffer), out_ep, {});
Packet packet;
bmp::Packet packet;
bmp::Header header {};
auto offset = header.deserialize_from(s_buffer);
if (header.flags != bmp::Flags::None) {
Expand All @@ -188,7 +129,7 @@ Packet Network::udp_read(ip::udp::endpoint& out_ep) {
return packet;
}

void Network::udp_write(Packet& packet, const ip::udp::endpoint& to_ep) {
void Network::udp_write(bmp::Packet& packet, const ip::udp::endpoint& to_ep) {
auto header = packet.finalize();
std::vector<uint8_t> data(header.size + bmp::Header::SERIALIZED_SIZE);
auto offset = header.serialize_to(data);
Expand Down Expand Up @@ -288,7 +229,7 @@ void Network::udp_read_main() {
endpoints->emplace(ep, id);
// now transfer them to the next state
beammp_debugf("Client {} successfully connected via UDP", client->id);
Packet state_change {
bmp::Packet state_change {
.purpose = bmp::Purpose::StateChangeModDownload,
};
client->tcp_write(state_change);
Expand Down Expand Up @@ -326,7 +267,7 @@ void Network::disconnect(ClientID id, const std::string& msg) {
});
clients->erase(id);
}
void Network::handle_packet(ClientID id, const Packet& packet) {
void Network::handle_packet(ClientID id, const bmp::Packet& packet) {
std::shared_ptr<Client> client;
{
auto clients = m_clients.synchronize();
Expand Down Expand Up @@ -358,23 +299,23 @@ void Network::handle_packet(ClientID id, const Packet& packet) {
break;
}
}
void Network::handle_identification(ClientID id, const Packet& packet, std::shared_ptr<Client>& client) {
void Network::handle_identification(ClientID id, const bmp::Packet& packet, std::shared_ptr<Client>& client) {
switch (packet.purpose) {
case bmp::ProtocolVersion: {
struct bmp::ProtocolVersion protocol_version { };
protocol_version.deserialize_from(packet.get_readable_data());
if (protocol_version.version.major != 1) {
beammp_debugf("{}: Protocol version bad", id);
// version bad
Packet protocol_v_bad_packet {
bmp::Packet protocol_v_bad_packet {
.purpose = bmp::ProtocolVersionBad,
};
client->tcp_write(protocol_v_bad_packet);
disconnect(id, fmt::format("bad protocol version: {}.{}.{}", protocol_version.version.major, protocol_version.version.minor, protocol_version.version.patch));
} else {
beammp_debugf("{}: Protocol version ok", id);
// version ok
Packet protocol_v_ok_packet {
bmp::Packet protocol_v_ok_packet {
.purpose = bmp::ProtocolVersionOk,
};
client->tcp_write(protocol_v_ok_packet);
Expand Down Expand Up @@ -408,14 +349,14 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar
.value = "Official BeamMP Server (BeamMP Ltd.)",
},
};
Packet sinfo_packet {
bmp::Packet sinfo_packet {
.purpose = bmp::ServerInfo,
.raw_data = std::vector<uint8_t>(1024),
};
sinfo.serialize_to(sinfo_packet.raw_data);
client->tcp_write(sinfo_packet);
// now transfer to next state
Packet auth_state {
bmp::Packet auth_state {
.purpose = bmp::StateChangeAuthentication,
};
client->tcp_write(auth_state);
Expand Down Expand Up @@ -468,7 +409,7 @@ void Network::authenticate_user(const std::string& public_key, std::shared_ptr<C
}
}

void Network::handle_authentication(ClientID id, const Packet& packet, std::shared_ptr<Client>& client) {
void Network::handle_authentication(ClientID id, const bmp::Packet& packet, std::shared_ptr<Client>& client) {
switch (packet.purpose) {
case bmp::Purpose::PlayerPublicKey: {
auto packet_data = packet.get_readable_data();
Expand All @@ -479,7 +420,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
// propragate to client and disconnect
auto err = std::string(e.what());
beammp_errorf("Client {} failed to authenticate: {}", id, err);
Packet auth_fail_packet {
bmp::Packet auth_fail_packet {
.purpose = bmp::Purpose::AuthFailed,
.raw_data = std::vector<uint8_t>(err.begin(), err.end()),
};
Expand All @@ -504,14 +445,14 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
});

if (NotAllowed) {
Packet auth_fail_packet {
bmp::Packet auth_fail_packet {
.purpose = bmp::Purpose::PlayerRejected
};
client->tcp_write(auth_fail_packet);
disconnect(id, "Rejected by a plugin");
return;
} else if (NotAllowedWithReason) {
Packet auth_fail_packet {
bmp::Packet auth_fail_packet {
.purpose = bmp::Purpose::PlayerRejected,
.raw_data = std::vector<uint8_t>(Reason.begin(), Reason.end()),
};
Expand All @@ -521,7 +462,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar
}
beammp_debugf("Client {} successfully authenticated as {} '{}'", id, client->role.get(), client->name.get());
// send auth ok since auth succeeded
Packet auth_ok {
bmp::Packet auth_ok {
.purpose = bmp::Purpose::AuthOk,
.raw_data = std::vector<uint8_t>(4),
};
Expand All @@ -534,7 +475,7 @@ void Network::handle_authentication(ClientID id, const Packet& packet, std::shar

// send the udp start packet, which should get the client to start udp with
// this packet as the first message
Packet udp_start {
bmp::Packet udp_start {
.purpose = bmp::Purpose::StartUDP,
.raw_data = std::vector<uint8_t>(8),
};
Expand Down

0 comments on commit 9502048

Please sign in to comment.