Skip to content

Commit

Permalink
feat: crud users synchronously (#3928)
Browse files Browse the repository at this point in the history
* fix: make clients use auth by default

* fix: let skip auth flag only affect verify

* feat: tablets get user table remotely

* fix: use FLAGS_system_table_replica_num for user table

* feat: consistent user cruds

* fix: pass instance of tablet and nameserver into auth lambda to allow locking

* feat: best effort try to flush user data to all tablets

* fix: lock scope

* fix: stop user sync thread safely

* fix: default values for user table columns
  • Loading branch information
oh2024 authored May 20, 2024
1 parent 5bbf9e3 commit 21184d5
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 149 deletions.
13 changes: 5 additions & 8 deletions src/auth/user_access_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,17 @@ UserAccessManager::UserAccessManager(IteratorFactory iterator_factory)
UserAccessManager::~UserAccessManager() { StopSyncTask(); }

void UserAccessManager::StartSyncTask() {
sync_task_running_ = true;
sync_task_thread_ = std::thread([this] {
while (sync_task_running_) {
sync_task_thread_ = std::thread([this, fut = stop_promise_.get_future()] {
while (true) {
SyncWithDB();
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (fut.wait_for(std::chrono::minutes(15)) != std::future_status::timeout) return;
}
});
}

void UserAccessManager::StopSyncTask() {
sync_task_running_ = false;
if (sync_task_thread_.joinable()) {
sync_task_thread_.join();
}
stop_promise_.set_value();
sync_task_thread_.join();
}

void UserAccessManager::SyncWithDB() {
Expand Down
6 changes: 3 additions & 3 deletions src/auth/user_access_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <chrono>
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <thread>
Expand All @@ -38,14 +39,13 @@ class UserAccessManager {

~UserAccessManager();
bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password);
void SyncWithDB();

private:
IteratorFactory user_table_iterator_factory_;
RefreshableMap<std::string, std::string> user_map_;
std::atomic<bool> sync_task_running_{false};
std::thread sync_task_thread_;

void SyncWithDB();
std::promise<void> stop_promise_;
void StartSyncTask();
void StopSyncTask();
};
Expand Down
5 changes: 4 additions & 1 deletion src/base/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ enum ReturnCode {
kSQLRunError = 1001,
kRPCRunError = 1002,
kServerConnError = 1003,
kRPCError = 1004 // brpc controller error
kRPCError = 1004, // brpc controller error

// auth
kFlushPrivilegesFailed = 1100 // brpc controller error
};

struct Status {
Expand Down
28 changes: 28 additions & 0 deletions src/client/ns_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <utility>

#include "base/strings.h"
#include "ns_client.h"

DECLARE_int32(request_timeout_ms);
namespace openmldb {
Expand Down Expand Up @@ -302,6 +303,33 @@ bool NsClient::CreateTable(const ::openmldb::nameserver::TableInfo& table_info,

bool NsClient::DropTable(const std::string& name, std::string& msg) { return DropTable(GetDb(), name, msg); }

bool NsClient::PutUser(const std::string& host, const std::string& name, const std::string& password) {
::openmldb::nameserver::PutUserRequest request;
request.set_host(host);
request.set_name(name);
request.set_password(password);
::openmldb::nameserver::GeneralResponse response;
bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::PutUser, &request, &response,
FLAGS_request_timeout_ms, 1);
if (ok && response.code() == 0) {
return true;
}
return false;
}

bool NsClient::DeleteUser(const std::string& host, const std::string& name) {
::openmldb::nameserver::DeleteUserRequest request;
request.set_host(host);
request.set_name(name);
::openmldb::nameserver::GeneralResponse response;
bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::DeleteUser, &request, &response,
FLAGS_request_timeout_ms, 1);
if (ok && response.code() == 0) {
return true;
}
return false;
}

bool NsClient::DropTable(const std::string& db, const std::string& name, std::string& msg) {
::openmldb::nameserver::DropTableRequest request;
request.set_name(name);
Expand Down
4 changes: 4 additions & 0 deletions src/client/ns_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class NsClient : public Client {

bool DropTable(const std::string& name, std::string& msg); // NOLINT

bool PutUser(const std::string& host, const std::string& name, const std::string& password); // NOLINT

bool DeleteUser(const std::string& host, const std::string& name); // NOLINT

bool DropTable(const std::string& db, const std::string& name,
std::string& msg); // NOLINT

Expand Down
12 changes: 12 additions & 0 deletions src/client/tablet_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "codec/sql_rpc_row_codec.h"
#include "common/timer.h"
#include "sdk/sql_request_row.h"
#include "tablet_client.h"

DECLARE_int32(request_max_retry);
DECLARE_int32(request_timeout_ms);
Expand Down Expand Up @@ -1414,5 +1415,16 @@ bool TabletClient::GetAndFlushDeployStats(::openmldb::api::DeployStatsResponse*
return ok && res->code() == 0;
}

bool TabletClient::FlushPrivileges() {
::openmldb::api::EmptyRequest request;
::openmldb::api::GeneralResponse response;

bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::FlushPrivileges, &request, &response,
FLAGS_request_timeout_ms, 1);
if (ok && response.code() == 0) {
return true;
}
return false;
}
} // namespace client
} // namespace openmldb
2 changes: 2 additions & 0 deletions src/client/tablet_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ class TabletClient : public Client {

bool GetAndFlushDeployStats(::openmldb::api::DeployStatsResponse* res);

bool FlushPrivileges();

private:
base::Status LoadTableInternal(const ::openmldb::api::TableMeta& table_meta, std::shared_ptr<TaskInfo> task_info);

Expand Down
13 changes: 4 additions & 9 deletions src/cmd/openmldb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#endif
#include "apiserver/api_server_impl.h"
#include "auth/brpc_authenticator.h"
#include "auth/user_access_manager.h"
#include "boost/algorithm/string.hpp"
#include "boost/lexical_cast.hpp"
#include "brpc/server.h"
Expand Down Expand Up @@ -147,12 +146,10 @@ void StartNameServer() {
}

brpc::ServerOptions options;
std::unique_ptr<openmldb::auth::UserAccessManager> user_access_manager;
std::unique_ptr<openmldb::authn::BRPCAuthenticator> server_authenticator;
user_access_manager = std::make_unique<openmldb::auth::UserAccessManager>(name_server->GetSystemTableIterator());
server_authenticator = std::make_unique<openmldb::authn::BRPCAuthenticator>(
[&user_access_manager](const std::string& host, const std::string& username, const std::string& password) {
return user_access_manager->IsAuthenticated(host, username, password);
[name_server](const std::string& host, const std::string& username, const std::string& password) {
return name_server->IsAuthenticated(host, username, password);
});
options.auth = server_authenticator.get();

Expand Down Expand Up @@ -253,13 +250,11 @@ void StartTablet() {
exit(1);
}
brpc::ServerOptions options;
std::unique_ptr<openmldb::auth::UserAccessManager> user_access_manager;
std::unique_ptr<openmldb::authn::BRPCAuthenticator> server_authenticator;

user_access_manager = std::make_unique<openmldb::auth::UserAccessManager>(tablet->GetSystemTableIterator());
server_authenticator = std::make_unique<openmldb::authn::BRPCAuthenticator>(
[&user_access_manager](const std::string& host, const std::string& username, const std::string& password) {
return user_access_manager->IsAuthenticated(host, username, password);
[tablet](const std::string& host, const std::string& username, const std::string& password) {
return tablet->IsAuthenticated(host, username, password);
});
options.auth = server_authenticator.get();
options.num_threads = FLAGS_thread_pool_size;
Expand Down
3 changes: 0 additions & 3 deletions src/cmd/sql_cmd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ TEST_P(DBSDKTest, TestUser) {
ASSERT_TRUE(status.IsOK());
ASSERT_TRUE(true);
auto opt = sr->GetRouterOptions();
std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly
if (cs->IsClusterMode()) {
auto real_opt = std::dynamic_pointer_cast<sdk::SQLRouterOptions>(opt);
sdk::SQLRouterOptions opt1;
Expand All @@ -257,7 +256,6 @@ TEST_P(DBSDKTest, TestUser) {
ASSERT_TRUE(router != nullptr);
sr->ExecuteSQL(absl::StrCat("ALTER USER user1 SET OPTIONS(password='abc')"), &status);
ASSERT_TRUE(status.IsOK());
std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly
router = NewClusterSQLRouter(opt1);
ASSERT_FALSE(router != nullptr);
} else {
Expand All @@ -271,7 +269,6 @@ TEST_P(DBSDKTest, TestUser) {
ASSERT_TRUE(router != nullptr);
sr->ExecuteSQL(absl::StrCat("ALTER USER user1 SET OPTIONS(password='abc')"), &status);
ASSERT_TRUE(status.IsOK());
std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly
router = NewStandaloneSQLRouter(opt1);
ASSERT_FALSE(router != nullptr);
}
Expand Down
92 changes: 79 additions & 13 deletions src/nameserver/name_server_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "boost/bind.hpp"
#include "codec/row_codec.h"
#include "gflags/gflags.h"
#include "name_server_impl.h"
#include "schema/index_util.h"
#include "schema/schema_adapter.h"

Expand Down Expand Up @@ -522,7 +523,8 @@ NameServerImpl::NameServerImpl()
thread_pool_(1),
task_thread_pool_(FLAGS_name_server_task_pool_size),
rand_(0xdeadbeef),
startup_mode_(::openmldb::type::StartupMode::kStandalone) {}
startup_mode_(::openmldb::type::StartupMode::kStandalone),
user_access_manager_(GetSystemTableIterator()) {}

NameServerImpl::~NameServerImpl() {
running_.store(false, std::memory_order_release);
Expand Down Expand Up @@ -650,7 +652,7 @@ bool NameServerImpl::Recover() {
if (!RecoverExternalFunction()) {
return false;
}
return true;
return FlushPrivileges().OK();
}

bool NameServerImpl::RecoverExternalFunction() {
Expand Down Expand Up @@ -1377,8 +1379,8 @@ void NameServerImpl::ShowTablet(RpcController* controller, const ShowTabletReque
response->set_msg("ok");
}

base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std::string& user,
const std::string& password) {
base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::string& user,
const std::string& password) {
std::shared_ptr<TableInfo> table_info;
if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) {
return {ReturnCode::kTableIsNotExist, "user table does not exist"};
Expand All @@ -1388,13 +1390,13 @@ base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std
row_values.push_back(host);
row_values.push_back(user);
row_values.push_back(password);
row_values.push_back(""); // password_last_changed
row_values.push_back(""); // password_expired_time
row_values.push_back(""); // create_time
row_values.push_back(""); // update_time
row_values.push_back(""); // account_type
row_values.push_back(""); // privileges
row_values.push_back(""); // extra_info
row_values.push_back("0"); // password_last_changed
row_values.push_back("0"); // password_expired_time
row_values.push_back("0"); // create_time
row_values.push_back("0"); // update_time
row_values.push_back("1"); // account_type
row_values.push_back("0"); // privileges
row_values.push_back("null"); // extra_info

std::string encoded_row;
codec::RowCodec::EncodeRow(row_values, table_info->column_desc(), 1, encoded_row);
Expand All @@ -1410,11 +1412,56 @@ base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std
std::string endpoint = table_partition.partition_meta(meta_idx).endpoint();
auto table_ptr = GetTablet(endpoint);
if (!table_ptr->client_->Put(tid, 0, cur_ts, encoded_row, dimensions).OK()) {
return {ReturnCode::kPutFailed, "failed to create initial user entry"};
return {ReturnCode::kPutFailed, "failed to put user entry"};
}
break;
}
}
return FlushPrivileges();
}

base::Status NameServerImpl::DeleteUserRecord(const std::string& host, const std::string& user) {
std::shared_ptr<TableInfo> table_info;
if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) {
return {ReturnCode::kTableIsNotExist, "user table does not exist"};
}
uint32_t tid = table_info->tid();
auto table_partition = table_info->table_partition(0); // only one partition for system table
std::string msg;
for (int meta_idx = 0; meta_idx < table_partition.partition_meta_size(); meta_idx++) {
if (table_partition.partition_meta(meta_idx).is_leader() &&
table_partition.partition_meta(meta_idx).is_alive()) {
uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000;
std::string endpoint = table_partition.partition_meta(meta_idx).endpoint();
auto table_ptr = GetTablet(endpoint);
if (!table_ptr->client_->Delete(tid, 0, host + "|" + user, "index", msg)) {
return {ReturnCode::kDeleteFailed, msg};
}

break;
}
}
return FlushPrivileges();
}

base::Status NameServerImpl::FlushPrivileges() {
user_access_manager_.SyncWithDB();
std::vector<std::string> failed_tablet_list;
{
std::lock_guard<std::mutex> lock(mu_);
for (const auto& tablet_pair : tablets_) {
const std::shared_ptr<TabletInfo>& tablet_info = tablet_pair.second;
if (tablet_info && tablet_info->Health() && tablet_info->client_) {
if (!tablet_info->client_->FlushPrivileges()) {
failed_tablet_list.push_back(tablet_pair.first);
}
}
}
}
if (failed_tablet_list.size() > 0) {
return {ReturnCode::kFlushPrivilegesFailed,
"Failed to flush privileges to tablets: " + boost::algorithm::join(failed_tablet_list, ", ")};
}
return {};
}

Expand Down Expand Up @@ -5593,7 +5640,7 @@ void NameServerImpl::OnLocked() {
CreateDatabaseOrExit(INTERNAL_DB);
if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) {
CreateSystemTableOrExit(SystemTableType::kUser);
InsertUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
}
if (IsClusterMode()) {
if (tablets_.size() < FLAGS_system_table_replica_num) {
Expand Down Expand Up @@ -9613,6 +9660,25 @@ NameServerImpl::GetSystemTableIterator() {
};
}

void NameServerImpl::PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response,
Closure* done) {
brpc::ClosureGuard done_guard(done);
auto status = PutUserRecord(request->host(), request->name(), request->password());
base::SetResponseStatus(status, response);
}

void NameServerImpl::DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response,
Closure* done) {
brpc::ClosureGuard done_guard(done);
auto status = DeleteUserRecord(request->host(), request->name());
base::SetResponseStatus(status, response);
}

bool NameServerImpl::IsAuthenticated(const std::string& host, const std::string& username,
const std::string& password) {
return user_access_manager_.IsAuthenticated(host, username, password);
}

bool NameServerImpl::RecoverProcedureInfo() {
db_table_sp_map_.clear();
db_sp_table_map_.clear();
Expand Down
Loading

0 comments on commit 21184d5

Please sign in to comment.