Skip to content

Commit

Permalink
feat: implement connection timeout functionality
Browse files Browse the repository at this point in the history
The `timeout` argument shuts down connections that have been idle for longer than the specified number of seconds.

Fixes #1677

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Jan 6, 2025
1 parent f663f8e commit 799158b
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 16 deletions.
7 changes: 4 additions & 3 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,10 @@ Connection::~Connection() {
UpdateLibNameVerMap(lib_name_, lib_ver_, -1);
}

// Returns false when no reply builder is attached; otherwise forwards the
// builder's IsSendActive() answer.
bool Connection::IsSending() const {
  if (!reply_builder_) {
    return false;
  }
  return reply_builder_->IsSendActive();
}

// Called from Connection::Shutdown() right after socket_->Shutdown call.
void Connection::OnShutdown() {
VLOG(1) << "Connection::OnShutdown";
Expand Down Expand Up @@ -1617,9 +1621,6 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {

// Hands out a weak reference to this connection, tagged with the pool index of the
// proactor that currently owns the socket and with the connection id.
//
// Callers are not required to hold a subscription: ServerState::ConnectionsWatcherFb
// borrows references to ordinary idle connections in order to close them on timeout,
// so asserting `subscriptions > 0` here would trip in debug builds for every
// connection that times out.
Connection::WeakRef Connection::Borrow() {
  DCHECK(self_);

  return WeakRef(self_, queue_backpressure_, socket_->proactor()->GetPoolIndex(), id_);
}
Expand Down
10 changes: 10 additions & 0 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ class Connection : public util::Connection {
static void GetRequestSizeHistogramThreadLocal(std::string* hist);
static void TrackRequestSize(bool enable);

// Difference between the current time() reading and last_interaction_, narrowed to
// unsigned. NOTE(review): presumably last_interaction_ is also a time() value in
// seconds - confirm at the place where it is updated.
unsigned idle_time() const {
  const auto now = time(nullptr);
  return now - last_interaction_;
}

// Returns the current value of phase_. ServerState::ConnectionsWatcherFb reads it to
// decide whether an idle connection is eligible for a timeout shutdown.
Phase phase() const {
return phase_;
}

bool IsSending() const;

protected:
void OnShutdown() override;
void OnPreMigrateThread() override;
Expand Down
5 changes: 3 additions & 2 deletions src/server/blocking_controller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class BlockingControllerTest : public Test {
}
void SetUp() override;
void TearDown() override;

// Creates the ServerState of the thread that runs the suite setup and allocates its
// FacadeStats.
static void SetUpTestSuite() {
  // Init takes (thread_index, num_shards, main_listener, registry). The test has
  // neither a listener nor an ACL registry, so both are null; with a null listener
  // Init does not launch the connections watcher fiber. Init must be called exactly
  // once here because every call allocates a new ServerState.
  ServerState::Init(kNumThreads, kNumThreads, nullptr, nullptr);
  facade::tl_facade_stats = new facade::FacadeStats;
}

Expand All @@ -45,7 +46,7 @@ void BlockingControllerTest::SetUp() {
pp_.reset(fb2::Pool::Epoll(kNumThreads));
pp_->Run();
pp_->AwaitBrief([](unsigned index, ProactorBase* p) {
ServerState::Init(index, kNumThreads, nullptr);
ServerState::Init(index, kNumThreads, nullptr, nullptr);
if (facade::tl_facade_stats == nullptr) {
facade::tl_facade_stats = new facade::FacadeStats;
}
Expand Down
18 changes: 12 additions & 6 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
config_registry.RegisterMutable("migration_finalization_timeout_ms");
config_registry.RegisterMutable("table_growth_margin");
config_registry.RegisterMutable("tcp_keepalive");
config_registry.RegisterMutable("timeout");
config_registry.RegisterMutable("managed_service_info");

config_registry.RegisterMutable(
Expand Down Expand Up @@ -851,19 +852,24 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
shard_num = pp_.size();
}

CHECK(!listeners.empty());

// We assume that listeners.front() is the main_listener
// see dfly_main RunEngine
facade::Listener* main_listener = listeners.front();

ChannelStore* cs = new ChannelStore{};
// Must initialize before the shard_set because EngineShard::Init references ServerState.
pp_.AwaitBrief([&](uint32_t index, ProactorBase* pb) {
tl_facade_stats = new FacadeStats;
ServerState::Init(index, shard_num, &user_registry_);
ServerState::Init(index, shard_num, main_listener, &user_registry_);
ServerState::tlocal()->UpdateChannelStore(cs);
});

const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
// We assume that listeners.front() is the main_listener
// see dfly_main RunEngine
if (!tcp_disabled && !listeners.empty()) {
acl_family_.Init(listeners.front(), &user_registry_);

if (!tcp_disabled) {
acl_family_.Init(main_listener, &user_registry_);
}

// Initialize shard_set with a callback running once in a while in the shard threads.
Expand Down Expand Up @@ -905,7 +911,7 @@ void Service::Shutdown() {
shard_set->Shutdown();
Transaction::Shutdown();

pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });

// wait for all the pending callbacks to stop.
ThisFiber::SleepFor(10ms);
Expand Down
63 changes: 62 additions & 1 deletion src/server/server_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@ extern "C" {
#include "base/flags.h"
#include "base/logging.h"
#include "facade/conn_context.h"
#include "facade/dragonfly_connection.h"
#include "server/journal/journal.h"
#include "util/listener_interface.h"

ABSL_FLAG(uint32_t, interpreter_per_thread, 10, "Lua interpreters per thread");
ABSL_FLAG(uint32_t, timeout, 0,
"Close the connection after it is idle for N seconds (0 to disable)");

namespace dfly {

using namespace std;
using namespace std::chrono_literals;

__thread ServerState* ServerState::state_ = nullptr;

ServerState::Stats::Stats(unsigned num_shards) : tx_width_freq_arr(num_shards) {
Expand Down Expand Up @@ -102,21 +109,33 @@ ServerState::ServerState() : interpreter_mgr_{absl::GetFlag(FLAGS_interpreter_pe
}

// Joins the connections watcher fiber, if one was started by Init, before the state
// goes away.
// NOTE(review): ConnectionsWatcherFb only leaves its loop once gstate_ is
// SHUTTING_DOWN, so EnterLameDuck() apparently has to be called before destruction
// or this join will not return - confirm against the shutdown sequence.
ServerState::~ServerState() {
watcher_fiber_.JoinIfNeeded();
}

void ServerState::Init(uint32_t thread_index, uint32_t num_shards, acl::UserRegistry* registry) {
void ServerState::Init(uint32_t thread_index, uint32_t num_shards,
util::ListenerInterface* main_listener, acl::UserRegistry* registry) {
state_ = new ServerState();
state_->gstate_ = GlobalState::ACTIVE;
state_->thread_index_ = thread_index;
state_->user_registry = registry;
state_->stats = Stats(num_shards);
if (main_listener) {
state_->watcher_fiber_ = util::fb2::Fiber(
util::fb2::Launch::post, "ConnectionsWatcher",
[state = state_, main_listener] { state->ConnectionsWatcherFb(main_listener); });
}
}

// Deletes the calling thread's ServerState and then clears the thread-local pointer.
// The destructor joins the watcher fiber, so this call may have to wait for that
// fiber to finish. NOTE(review): this is presumably why it must run in a context
// that is allowed to suspend - confirm at the call site.
void ServerState::Destroy() {
delete state_;
state_ = nullptr;
}

// Switches this state to SHUTTING_DOWN and notifies watcher_cv_, so the connections
// watcher fiber re-evaluates its exit predicate right away instead of waiting for
// its one-second poll interval to elapse.
void ServerState::EnterLameDuck() {
gstate_ = GlobalState::SHUTTING_DOWN;
watcher_cv_.notify_all();
}

ServerState::MemoryUsageStats ServerState::GetMemoryUsage(uint64_t now_ns) {
static constexpr uint64_t kCacheEveryNs = 1000;
if (now_ns > used_mem_last_update_ + kCacheEveryNs) {
Expand Down Expand Up @@ -208,4 +227,46 @@ ServerState* ServerState::SafeTLocal() {
// True only when the slow log shard is enabled and the given latency (microseconds)
// is at least log_slower_than_usec.
bool ServerState::ShouldLogSlowCmd(unsigned latency_usec) const {
  if (!slow_log_shard_.IsEnabled()) {
    return false;
  }
  return latency_usec >= log_slower_than_usec;
}

// Body of the "ConnectionsWatcher" fiber started by Init.
//
// Once per second (or sooner when watcher_cv_ is notified) it wakes up and:
//   * exits if the state is SHUTTING_DOWN;
//   * does nothing when the `timeout` flag is 0 (feature disabled);
//   * otherwise walks the connections of `main` on this thread and shuts down those
//     that are reading from the socket or sending and whose idle time exceeds the flag.
void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
  for (;;) {
    util::fb2::NoOpLock noop;
    const bool shutting_down =
        watcher_cv_.wait_for(noop, 1s, [this] { return gstate_ == GlobalState::SHUTTING_DOWN; });
    if (shutting_down) {
      return;
    }

    // The flag is re-read on every wake-up, so a changed value is picked up here.
    const uint32_t idle_limit = absl::GetFlag(FLAGS_timeout);
    if (idle_limit == 0) {
      continue;
    }

    // Weak references are collected first and resolved later: ShutdownSelf may block
    // this fiber, and connections gathered here can be destroyed in the meantime.
    // Resolving a weak ref tells us whether the connection is still alive.
    vector<facade::Connection::WeakRef> expired;

    auto collect_expired = [&](unsigned thread_index, util::Connection* conn) {
      using Phase = facade::Connection::Phase;
      auto* dfly_conn = static_cast<facade::Connection*>(conn);

      const bool eligible = dfly_conn->phase() == Phase::READ_SOCKET || dfly_conn->IsSending();
      if (!eligible) {
        return;
      }
      if (dfly_conn->idle_time() <= idle_limit) {
        return;
      }
      expired.push_back(dfly_conn->Borrow());
    };

    // TODO: traverse the connections in batches to avoid blocking the thread
    // for too long.
    main->TraverseConnectionsOnThread(collect_expired);

    for (auto& weak : expired) {
      facade::Connection* live = weak.Get();
      if (live == nullptr) {
        continue;  // Already gone while we were busy shutting down others.
      }
      VLOG(1) << "Closing connection due to timeout: " << live->GetClientInfo();
      live->ShutdownSelf();
    }
  }
}

} // end of namespace dfly
18 changes: 14 additions & 4 deletions src/server/server_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ namespace facade {
class Connection;
}

namespace util {
class ListenerInterface;
}

namespace dfly {

namespace journal {
Expand Down Expand Up @@ -150,12 +154,11 @@ class ServerState { // public struct - to allow initialization.
ServerState();
~ServerState();

static void Init(uint32_t thread_index, uint32_t num_shards, acl::UserRegistry* registry);
static void Init(uint32_t thread_index, uint32_t num_shards,
util::ListenerInterface* main_listener, acl::UserRegistry* registry);
static void Destroy();

void EnterLameDuck() {
state_->gstate_ = GlobalState::SHUTTING_DOWN;
}
void EnterLameDuck();

void TxCountInc() {
++live_transactions_;
Expand Down Expand Up @@ -302,6 +305,9 @@ class ServerState { // public struct - to allow initialization.
size_t serialization_max_chunk_size;

private:
// A fiber constantly watching connections on the main listener.
void ConnectionsWatcherFb(util::ListenerInterface* main);

int64_t live_transactions_ = 0;
SlowLogShard slow_log_shard_;
mi_heap_t* data_heap_;
Expand All @@ -321,6 +327,10 @@ class ServerState { // public struct - to allow initialization.
int client_pauses_[2] = {};
util::fb2::EventCount client_pause_ec_;

// Monitors connections. Currently responsible for closing timed out connections.
util::fb2::Fiber watcher_fiber_;
util::fb2::CondVarAny watcher_cv_;

using Counter = util::SlidingCounter<7>;
Counter qps_;

Expand Down
11 changes: 11 additions & 0 deletions tests/dragonfly/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,3 +1038,14 @@ async def test_lib_name_ver(async_client: aioredis.Redis):
assert len(list) == 1
assert list[0]["lib-name"] == "dragonfly"
assert list[0]["lib-ver"] == "1.2.3.4"


@dfly_args({"timeout": 1})
async def test_timeout(df_server: DflyInstance, async_client: aioredis.Redis):
    """An idle client must be disconnected by the server once `timeout` has elapsed.

    A second client is opened and left idle while `async_client` polls CLIENT LIST.
    The test passes as soon as only one client remains and fails if the idle client
    is still listed after the polling budget (about 5 seconds) is used up.
    """
    another_client = df_server.client()
    await another_client.ping()
    clients = await async_client.client_list()
    assert len(clients) == 2

    # The server measures idle time in whole seconds, closes only when it is strictly
    # greater than `timeout`, and its watcher wakes once per second, so the drop can
    # happen noticeably later than `timeout` seconds after the last command. Poll
    # instead of sleeping a fixed amount; polling also keeps `async_client` itself
    # from becoming idle long enough to be closed.
    for _ in range(20):
        await asyncio.sleep(0.25)
        clients = await async_client.client_list()
        if len(clients) == 1:
            break
    assert len(clients) == 1

0 comments on commit 799158b

Please sign in to comment.