Skip to content

Commit

Permalink
refactor: dont use strings and uint64_t in dave, use snowflake type (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis authored Oct 23, 2024
1 parent 7d0128c commit 83a86f6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 59 deletions.
4 changes: 2 additions & 2 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client
* @brief The list of users that have E2EE potentially enabled for
* DAVE protocol.
*/
std::set<std::string> dave_mls_user_list;
std::set<dpp::snowflake> dave_mls_user_list;

/**
* @brief The list of users that have left the voice channel but
* not yet removed from MLS group.
*/
std::set<std::string> dave_mls_pending_remove_list;
std::set<dpp::snowflake> dave_mls_pending_remove_list;

/**
* @brief File descriptor for UDP connection
Expand Down
51 changes: 21 additions & 30 deletions src/dpp/dave/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <iostream>
#include <mls/crypto.h>
#include <mls/messages.h>
#include <dpp/export.h>
#include <dpp/snowflake.h>
#include <mls/state.h>
#include <dpp/cluster.h>
#include "mls_key_ratchet.h"
Expand All @@ -50,20 +52,20 @@ struct queued_proposal {
::mlspp::bytes_ns::bytes ref;
};

session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept
session::session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept
: signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster)
{
creator.log(dpp::ll_debug, "Creating a new MLS session");
}

session::~session() noexcept = default;

void session::init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
void session::init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
reset();

bot_user_id = self_user_id;

creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(group_id));
creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + group_id.str());
session_protocol_version = version;
session_group_id = std::move(big_endian_bytes_from(group_id).as_vec());

Expand Down Expand Up @@ -123,7 +125,7 @@ catch (const std::exception& e) {
return;
}

std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept
std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
try {
if (!pending_group_state && !current_state) {
creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state");
Expand Down Expand Up @@ -183,9 +185,7 @@ try {
for (const auto& proposal_message : messages) {
auto validated_content = state_with_proposals->unwrap(proposal_message);

if (!validate_proposal_message(validated_content.authenticated_content(),
*state_with_proposals,
recognised_user_ids)) {
if (!validate_proposal_message(validated_content.authenticated_content(), *state_with_proposals, recognised_user_ids)) {
return std::nullopt;
}

Expand Down Expand Up @@ -238,9 +238,9 @@ catch (const std::exception& e) {
return std::nullopt;
}

bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const
bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const
{
std::string uid = user_credential_to_string(cred, session_protocol_version);
dpp::snowflake uid(user_credential_to_string(cred, session_protocol_version));
if (uid.empty()) {
creator.log(dpp::ll_warning, "Attempted to verify credential of unexpected type");
return false;
Expand All @@ -254,7 +254,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<st
return true;
}

bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const {
bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const {
if (message.wire_format != ::mlspp::WireFormat::mls_public_message) {
creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage");
TRACK_MLS_ERROR("Invalid proposal wire format");
Expand Down Expand Up @@ -357,7 +357,7 @@ catch (const std::exception& e) {
return failed_t{};
}

std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept
std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
try {
if (!has_cryptographic_state_for_welcome()) {
creator.log(dpp::ll_warning, "Missing local crypto state necessary to process MLS welcome");
Expand Down Expand Up @@ -461,7 +461,7 @@ bool session::has_cryptographic_state_for_welcome() const noexcept
return join_key_package && join_init_private_key && signature_private_key && hpke_private_key;
}

bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const
bool session::verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const
{
if (!mls_external_sender) {
creator.log(dpp::ll_warning, "Cannot verify MLS welcome without an external sender");
Expand Down Expand Up @@ -502,13 +502,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::st
return true;
}

void session::init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
void session::init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
try {
auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version);

if (!transient_key) {
if (!signing_key_id.empty()) {
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id, session_protocol_version);
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id.str(), session_protocol_version);
if (!transient_key) {
creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting");
return;
Expand All @@ -522,7 +522,7 @@ try {

signature_private_key = transient_key;

auto self_credential = create_user_credential(self_user_id, session_protocol_version);
auto self_credential = create_user_credential(self_user_id.str(), session_protocol_version);
hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite));
self_leaf_node = std::make_unique<::mlspp::LeafNode>(
ciphersuite, hpke_private_key->public_key, signature_private_key->public_key, std::move(self_credential),
Expand Down Expand Up @@ -608,7 +608,7 @@ catch (const std::exception& e) {
return {};
}

std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string const& user_id) const noexcept
std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(dpp::snowflake user_id) const noexcept
{
if (!current_state) {
creator.log(dpp::ll_warning, "Cannot get key ratchet without an established MLS group");
Expand All @@ -617,7 +617,7 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons

// change the string user ID to a little endian 64 bit user ID
// TODO: Make this use dpp::snowflake
auto u64_user_id = strtoull(user_id.c_str(), nullptr, 10);
uint64_t u64_user_id = user_id;
auto user_id_bytes = ::mlspp::bytes_ns::bytes(sizeof(u64_user_id));
memcpy(user_id_bytes.data(), &u64_user_id, sizeof(u64_user_id));

Expand All @@ -629,14 +629,14 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
return std::make_unique<mls_key_ratchet>(creator, current_state->cipher_suite(), std::move(secret));
}

void session::get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept
void session::get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept
try {
if (!current_state || !signature_private_key) {
throw std::invalid_argument("No established MLS group");
}

uint64_t remote_user_id = strtoull(user_id.c_str(), nullptr, 10);
uint64_t self_user_id = strtoull(bot_user_id.c_str(), nullptr, 10);
uint64_t remote_user_id = user_id;
uint64_t self_user_id = bot_user_id;

auto it = roster.find(remote_user_id);
if (it == roster.end()) {
Expand Down Expand Up @@ -687,16 +687,7 @@ try {

std::vector<uint8_t> out(hash_len);

int ret = EVP_PBE_scrypt((const char*)data.data(),
data.size(),
salt,
sizeof(salt),
N,
r,
p,
max_mem,
out.data(),
out.size());
int ret = EVP_PBE_scrypt((const char*)data.data(), data.size(), salt, sizeof(salt), N, r, p, max_mem, out.data(), out.size());

if (ret == 1) {
callback(out);
Expand Down
26 changes: 14 additions & 12 deletions src/dpp/dave/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <vector>
#include <map>
#include <set>
#include <dpp/export.h>
#include <dpp/snowflake.h>
#include "persisted_key_pair.h"
#include "key_ratchet.h"
#include "version.h"
Expand Down Expand Up @@ -73,7 +75,7 @@ class session { // NOLINT
* @param auth_session_id auth session id (set to empty string to use a transient key pair)
* @param callback callback for failure
*/
session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept;
session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept;

/**
* @brief Destructor
Expand All @@ -90,7 +92,7 @@ class session { // NOLINT
* @param self_user_id bot's user id
* @param transient_key transient private key
*/
void init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
void init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;

/**
* @brief Reset the session to defaults
Expand Down Expand Up @@ -129,7 +131,7 @@ class session { // NOLINT
* @param recognised_user_ids list of recognised user IDs
* @return optional vector to send in reply as commit welcome
*/
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept;
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;

/**
* @brief Process commit message from discord websocket
Expand All @@ -144,7 +146,7 @@ class session { // NOLINT
* @param recognised_user_ids Recognised user ID list
* @return roster list of people in the vc
*/
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept;
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;

/**
* @brief Get the bot user's key package for sending to websocket
Expand All @@ -157,7 +159,7 @@ class session { // NOLINT
* @param user_id User id to get ratchet for
* @return The user's key ratchet for use in an encryptor or decryptor
*/
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(std::string const& user_id) const noexcept;
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(dpp::snowflake user_id) const noexcept;

/**
* @brief callback for completion of pairwise fingerprint
Expand All @@ -172,15 +174,15 @@ class session { // NOLINT
* @param user_id User ID to get fingerprint for
* @param callback Callback for completion
*/
void get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept;
void get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept;

private:
/**
* @brief Initialise leaf node
* @param self_user_id Bot user id
* @param transient_key Transient key
*/
void init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
void init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;

/**
* @brief Reset join key
Expand All @@ -204,7 +206,7 @@ class session { // NOLINT
* @param recognised_user_ids list of recognised user IDs
* @return
*/
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Validate proposals message
Expand All @@ -213,15 +215,15 @@ class session { // NOLINT
* @param recognised_user_ids recognised list of user IDs
* @return true if validated
*/
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Verify that welcome state is valid
* @param state current state
* @param recognised_user_ids list of recognised user IDs
* @return
*/
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Check if can process a commit now
Expand Down Expand Up @@ -260,12 +262,12 @@ class session { // NOLINT
/**
* @brief Signing key id
*/
std::string signing_key_id;
dpp::snowflake signing_key_id;

/**
* @brief The bot's user snowflake ID
*/
std::string bot_user_id;
dpp::snowflake bot_user_id;

/**
* @brief The bot's key pair context
Expand Down
28 changes: 13 additions & 15 deletions src/dpp/voice/enabled/handle_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,15 @@ void discord_voice_client::update_ratchets(bool force) {
*/
log(ll_debug, "Updating MLS ratchets for " + std::to_string(dave_mls_user_list.size() + 1) + " user(s)");
for (const auto& user : dave_mls_user_list) {
dpp::snowflake u{user};
if (u == creator->me.id) {
if (user == creator->me.id) {
continue;
}
decryptor_list::iterator decryptor;
/* New user join/old user leave - insert new ratchets if they don't exist */
decryptor = mls_state->decryptors.find(u);
decryptor = mls_state->decryptors.find(user.str());
if (decryptor == mls_state->decryptors.end()) {
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
auto [iter, inserted] = mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::decryptor>(*creator));
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user.str() + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
auto [iter, inserted] = mls_state->decryptors.emplace(user.str(), std::make_unique<dpp::dave::decryptor>(*creator));
decryptor = iter;
}
decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY);
Expand All @@ -72,7 +71,7 @@ void discord_voice_client::update_ratchets(bool force) {
if (mls_state->encryptor) {
/* Updating key rachet should always be done on execute transition. Generally after group member add/remove. */
log(ll_debug, "Setting key ratchet for sending audio...");
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str()));
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id));
}

/**
Expand Down Expand Up @@ -146,7 +145,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));

/* We should always recognize our own selves, but do we? */
dave_mls_user_list.insert(this->creator->me.id.str());
dave_mls_user_list.insert(this->creator->me.id);

auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list);

Expand Down Expand Up @@ -222,7 +221,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod

/* Remove this user from pending remove list if exist */
for (const auto &user : joining_dave_users) {
dave_mls_pending_remove_list.erase(user);
dave_mls_pending_remove_list.erase(dpp::snowflake(user));
}

log(ll_debug, "New of clients in voice channel: " + std::to_string(joining_dave_users.size()) + " total is " + std::to_string(dave_mls_user_list.size()));
Expand Down Expand Up @@ -298,7 +297,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
}

/* Mark this user for remove on immediate upgrade */
dave_mls_pending_remove_list.insert(u_id.str());
dave_mls_pending_remove_list.insert(u_id);

if (!creator->on_voice_client_disconnect.empty()) {
voice_client_disconnect_t vcd(nullptr, data);
Expand Down Expand Up @@ -575,12 +574,12 @@ void discord_voice_client::reinit_dave_mls_group() {
if (mls_state->dave_session == nullptr) {
mls_state->dave_session = std::make_unique<dave::mls::session>(
*creator,
nullptr, "", [this](std::string const &s1, std::string const &s2) {
nullptr, snowflake(), [this](std::string const &s1, std::string const &s2) {
log(ll_debug, "DAVE: " + s1 + ", " + s2);
});
}

mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id, mls_state->mls_key);

auto key_response = mls_state->dave_session->get_marshalled_key_package();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
Expand Down Expand Up @@ -630,12 +629,11 @@ void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rma
}

dpp::snowflake u_id(k);
auto u_id_str = u_id.str();

log(ll_debug, "Removed user from MLS Group: " + u_id_str);
log(ll_debug, "Removed user from MLS Group: " + u_id.str());

dave_mls_user_list.erase(u_id_str);
dave_mls_pending_remove_list.erase(u_id_str);
dave_mls_user_list.erase(u_id);
dave_mls_pending_remove_list.erase(u_id);

/* Remove this user's key ratchet */
mls_state->decryptors.erase(u_id);
Expand Down

0 comments on commit 83a86f6

Please sign in to comment.