diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc
index da7d884fc9ca..6ea181f6232c 100644
--- a/src/facade/dragonfly_connection.cc
+++ b/src/facade/dragonfly_connection.cc
@@ -573,6 +573,10 @@ Connection::~Connection() {
   UpdateLibNameVerMap(lib_name_, lib_ver_, -1);
 }
 
+bool Connection::IsSending() const {
+  return reply_builder_ && reply_builder_->IsSendActive();
+}
+
 // Called from Connection::Shutdown() right after socket_->Shutdown call.
 void Connection::OnShutdown() {
   VLOG(1) << "Connection::OnShutdown";
@@ -1617,9 +1621,6 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {
 
 Connection::WeakRef Connection::Borrow() {
   DCHECK(self_);
-  // If the connection is unaware of subscriptions, it could migrate threads, making this call
-  // unsafe. All external mechanisms that borrow references should register subscriptions.
-  DCHECK_GT(cc_->subscriptions, 0);
 
   return WeakRef(self_, queue_backpressure_, socket_->proactor()->GetPoolIndex(), id_);
 }
diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h
index 3eced43ae1b4..0cd6082d4907 100644
--- a/src/facade/dragonfly_connection.h
+++ b/src/facade/dragonfly_connection.h
@@ -311,6 +311,16 @@ class Connection : public util::Connection {
   static void GetRequestSizeHistogramThreadLocal(std::string* hist);
   static void TrackRequestSize(bool enable);
 
+  unsigned idle_time() const {
+    return time(nullptr) - last_interaction_;
+  }
+
+  Phase phase() const {
+    return phase_;
+  }
+
+  bool IsSending() const;
+
  protected:
   void OnShutdown() override;
   void OnPreMigrateThread() override;
diff --git a/src/server/blocking_controller_test.cc b/src/server/blocking_controller_test.cc
index 9f133c891d0a..1e0b74202a83 100644
--- a/src/server/blocking_controller_test.cc
+++ b/src/server/blocking_controller_test.cc
@@ -29,8 +29,9 @@ class BlockingControllerTest : public Test {
   }
   void SetUp() override;
   void TearDown() override;
+
   static void SetUpTestSuite() {
-    ServerState::Init(kNumThreads, kNumThreads, nullptr);
+    ServerState::Init(kNumThreads, kNumThreads, nullptr, nullptr);
     facade::tl_facade_stats = new facade::FacadeStats;
   }
@@ -45,7 +46,7 @@ void BlockingControllerTest::SetUp() {
   pp_.reset(fb2::Pool::Epoll(kNumThreads));
   pp_->Run();
   pp_->AwaitBrief([](unsigned index, ProactorBase* p) {
-    ServerState::Init(index, kNumThreads, nullptr);
+    ServerState::Init(index, kNumThreads, nullptr, nullptr);
     if (facade::tl_facade_stats == nullptr) {
       facade::tl_facade_stats = new facade::FacadeStats;
     }
diff --git a/src/server/main_service.cc b/src/server/main_service.cc
index 2af84299892d..a8742e2e1045 100644
--- a/src/server/main_service.cc
+++ b/src/server/main_service.cc
@@ -813,6 +813,7 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*> listeners,
   config_registry.RegisterMutable("migration_finalization_timeout_ms");
   config_registry.RegisterMutable("table_growth_margin");
   config_registry.RegisterMutable("tcp_keepalive");
+  config_registry.RegisterMutable("timeout");
   config_registry.RegisterMutable("managed_service_info");
 
   config_registry.RegisterMutable(
@@ -848,19 +849,23 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*> listeners,
     shard_num = pp_.size();
   }
 
+  // We assume that listeners.front() is the main_listener; see RunEngine in dfly_main.
+  // In unit tests, the listeners vector is empty.
+  facade::Listener* main_listener = listeners.empty() ? nullptr : listeners.front();
+
   ChannelStore* cs = new ChannelStore{};
   // Must initialize before the shard_set because EngineShard::Init references ServerState.
   pp_.AwaitBrief([&](uint32_t index, ProactorBase* pb) {
     tl_facade_stats = new FacadeStats;
-    ServerState::Init(index, shard_num, &user_registry_);
+    ServerState::Init(index, shard_num, main_listener, &user_registry_);
     ServerState::tlocal()->UpdateChannelStore(cs);
   });
 
   const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
-  // We assume that listeners.front() is the main_listener
-  // see dfly_main RunEngine
-  if (!tcp_disabled && !listeners.empty()) {
-    acl_family_.Init(listeners.front(), &user_registry_);
+  if (!tcp_disabled && main_listener) {
+    acl_family_.Init(main_listener, &user_registry_);
   }
 
   // Initialize shard_set with a callback running once in a while in the shard threads.
@@ -906,7 +911,7 @@ void Service::Shutdown() {
   shard_set->Shutdown();
   Transaction::Shutdown();
 
-  pp_.Await([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });
+  pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });
 
   // wait for all the pending callbacks to stop.
   ThisFiber::SleepFor(10ms);
diff --git a/src/server/server_state.cc b/src/server/server_state.cc
index 9bdedc1a7d28..e0db9126ddea 100644
--- a/src/server/server_state.cc
+++ b/src/server/server_state.cc
@@ -15,12 +15,19 @@ extern "C" {
 #include "base/flags.h"
 #include "base/logging.h"
 #include "facade/conn_context.h"
+#include "facade/dragonfly_connection.h"
 #include "server/journal/journal.h"
+#include "util/listener_interface.h"
 
 ABSL_FLAG(uint32_t, interpreter_per_thread, 10, "Lua interpreters per thread");
+ABSL_FLAG(uint32_t, timeout, 0,
+          "Close the connection after it is idle for N seconds (0 to disable)");
 
 namespace dfly {
 
+using namespace std;
+using namespace std::chrono_literals;
+
 __thread ServerState* ServerState::state_ = nullptr;
 
 ServerState::Stats::Stats(unsigned num_shards) : tx_width_freq_arr(num_shards) {
@@ -102,14 +109,21 @@ ServerState::ServerState() : interpreter_mgr_{absl::GetFlag(FLAGS_interpreter_pe
 }
 
 ServerState::~ServerState() {
+  watcher_fiber_.JoinIfNeeded();
 }
 
-void ServerState::Init(uint32_t thread_index, uint32_t num_shards, acl::UserRegistry* registry) {
+void ServerState::Init(uint32_t thread_index, uint32_t num_shards,
+                       util::ListenerInterface* main_listener, acl::UserRegistry* registry) {
   state_ = new ServerState();
   state_->gstate_ = GlobalState::ACTIVE;
   state_->thread_index_ = thread_index;
   state_->user_registry = registry;
   state_->stats = Stats(num_shards);
+  if (main_listener) {
+    state_->watcher_fiber_ = util::fb2::Fiber(
+        util::fb2::Launch::post, "ConnectionsWatcher",
+        [state = state_, main_listener] { state->ConnectionsWatcherFb(main_listener); });
+  }
 }
 
 void ServerState::Destroy() {
@@ -117,6 +131,11 @@
   state_ = nullptr;
 }
 
+void ServerState::EnterLameDuck() {
+  gstate_ = GlobalState::SHUTTING_DOWN;
+  watcher_cv_.notify_all();
+}
+
 ServerState::MemoryUsageStats ServerState::GetMemoryUsage(uint64_t now_ns) {
   static constexpr uint64_t kCacheEveryNs = 1000;
   if (now_ns > used_mem_last_update_ + kCacheEveryNs) {
@@ -208,4 +227,46 @@ ServerState* ServerState::SafeTLocal() {
 bool ServerState::ShouldLogSlowCmd(unsigned latency_usec) const {
   return slow_log_shard_.IsEnabled() && latency_usec >= log_slower_than_usec;
 }
+
+void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
+  while (true) {
+    util::fb2::NoOpLock noop;
+    if (watcher_cv_.wait_for(noop, 1s, [this] { return gstate_ == GlobalState::SHUTTING_DOWN; })) {
+      break;
+    }
+
+    uint32_t timeout = absl::GetFlag(FLAGS_timeout);
+    if (timeout == 0) {
+      continue;
+    }
+
+    // We use weak refs, because ShutdownSelf below can potentially block the fiber,
+    // and during this time some of the connections might be destroyed. Weak refs allow checking
+    // the validity of each connection.
+    vector<facade::Connection::WeakRef> conn_refs;
+
+    auto cb = [&](unsigned thread_index, util::Connection* conn) {
+      facade::Connection* dfly_conn = static_cast<facade::Connection*>(conn);
+      using Phase = facade::Connection::Phase;
+      auto phase = dfly_conn->phase();
+      if ((phase == Phase::READ_SOCKET || dfly_conn->IsSending()) &&
+          dfly_conn->idle_time() > timeout) {
+        conn_refs.push_back(dfly_conn->Borrow());
+      }
+    };
+
+    // TODO: traverse the connections in batches to avoid blocking
+    // the thread for too long.
+    main->TraverseConnectionsOnThread(cb);
+
+    for (auto& ref : conn_refs) {
+      facade::Connection* conn = ref.Get();
+      if (conn) {
+        VLOG(1) << "Closing connection due to timeout: " << conn->GetClientInfo();
+        conn->ShutdownSelf();
+      }
+    }
+  }
+}
+
 } // end of namespace dfly
diff --git a/src/server/server_state.h b/src/server/server_state.h
index 6ea43787f48f..044cd6774c00 100644
--- a/src/server/server_state.h
+++ b/src/server/server_state.h
@@ -23,6 +23,10 @@ namespace facade {
 class Connection;
 }
 
+namespace util {
+class ListenerInterface;
+}
+
 namespace dfly {
 
 namespace journal {
@@ -150,12 +154,11 @@ class ServerState {  // public struct - to allow initialization.
   ServerState();
   ~ServerState();
 
-  static void Init(uint32_t thread_index, uint32_t num_shards, acl::UserRegistry* registry);
+  static void Init(uint32_t thread_index, uint32_t num_shards,
+                   util::ListenerInterface* main_listener, acl::UserRegistry* registry);
   static void Destroy();
 
-  void EnterLameDuck() {
-    state_->gstate_ = GlobalState::SHUTTING_DOWN;
-  }
+  void EnterLameDuck();
 
   void TxCountInc() {
     ++live_transactions_;
@@ -302,6 +305,9 @@ class ServerState {  // public struct - to allow initialization.
   size_t serialization_max_chunk_size;
 
  private:
+  // A fiber constantly watching connections on the main listener.
+  void ConnectionsWatcherFb(util::ListenerInterface* main);
+
   int64_t live_transactions_ = 0;
   SlowLogShard slow_log_shard_;
   mi_heap_t* data_heap_;
@@ -321,6 +327,10 @@ class ServerState {  // public struct - to allow initialization.
   int client_pauses_[2] = {};
   util::fb2::EventCount client_pause_ec_;
 
+  // Monitors connections. Currently responsible for closing timed-out connections.
+  util::fb2::Fiber watcher_fiber_;
+  util::fb2::CondVarAny watcher_cv_;
+
   using Counter = util::SlidingCounter<7>;
   Counter qps_;
diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py
index 6b6eb1715463..868e89fe6337 100755
--- a/tests/dragonfly/connection_test.py
+++ b/tests/dragonfly/connection_test.py
@@ -1038,3 +1038,14 @@ async def test_lib_name_ver(async_client: aioredis.Redis):
     assert len(list) == 1
     assert list[0]["lib-name"] == "dragonfly"
     assert list[0]["lib-ver"] == "1.2.3.4"
+
+
+@dfly_args({"timeout": 1})
+async def test_timeout(df_server: DflyInstance, async_client: aioredis.Redis):
+    another_client = df_server.client()
+    await another_client.ping()
+    clients = await async_client.client_list()
+    assert len(clients) == 2
+    await asyncio.sleep(2)
+    clients = await async_client.client_list()
+    assert len(clients) == 1
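
The watcher added in server_state.cc follows a standard wake-or-shutdown loop: it sleeps on a
condition variable in one-second ticks, scans for idle connections on each wakeup, and exits
promptly once EnterLameDuck() flips gstate_ and calls notify_all(). For illustration only, here
is a minimal standalone sketch of that loop written with std::thread and std::condition_variable
in place of helio's util::fb2::Fiber and util::fb2::CondVarAny; the Watcher class and the scan
placeholder are hypothetical names, not part of this patch:

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

using namespace std::chrono_literals;

// Hypothetical stand-in for the watcher fiber; not from the patch.
class Watcher {
 public:
  Watcher() : thread_([this] { Run(); }) {}

  ~Watcher() {
    {
      std::lock_guard lk(mu_);
      shutting_down_ = true;  // mirrors gstate_ = GlobalState::SHUTTING_DOWN
    }
    cv_.notify_all();  // mirrors EnterLameDuck(): wake the watcher immediately
    thread_.join();    // mirrors watcher_fiber_.JoinIfNeeded() in ~ServerState()
  }

 private:
  void Run() {
    std::unique_lock lk(mu_);
    while (true) {
      // Wait one tick, or break early if the shutdown predicate becomes true.
      if (cv_.wait_for(lk, 1s, [this] { return shutting_down_; })) {
        break;
      }
      std::cout << "scan idle connections...\n";  // placeholder for the traversal
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  bool shutting_down_ = false;
  std::thread thread_;  // declared last so mu_/cv_ are initialized before Run() starts
};

int main() {
  Watcher w;
  std::this_thread::sleep_for(3s);
}  // ~Watcher requests shutdown; the loop exits without waiting out a full tick

The predicate form of wait_for is what makes shutdown responsive: it returns true as soon as the
flag is set and notify_all() fires, so the watcher never lingers for the remainder of its
one-second tick, and it is immune to spurious wakeups, which presumably is why the real code can
get away with a util::fb2::NoOpLock instead of a mutex.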
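ConnectionsWatcherFb also splits its work into two phases: while traversing the listener's
connections it only collects weak references, and the potentially blocking ShutdownSelf() calls
happen afterwards, upgrading each reference first because connections may be destroyed in between.
Below is a rough analogue of that collect-then-upgrade discipline using std::weak_ptr in place of
facade::Connection::WeakRef; Conn and Close are hypothetical stand-ins, not names from this patch:

#include <iostream>
#include <memory>
#include <vector>

// Hypothetical connection type; taking a weak_ptr corresponds to Borrow().
struct Conn {
  int id;
  void Close() { std::cout << "closing " << id << "\n"; }  // may block in real life
};

int main() {
  std::vector<std::shared_ptr<Conn>> registry = {std::make_shared<Conn>(Conn{1}),
                                                 std::make_shared<Conn>(Conn{2})};

  // Phase 1: traverse and collect weak references only; no blocking calls here.
  std::vector<std::weak_ptr<Conn>> victims;
  for (const auto& c : registry) {
    victims.push_back(c);
  }

  registry.pop_back();  // connection 2 is destroyed while we were not looking

  // Phase 2: upgrade each weak reference; skip connections that died in between.
  for (const auto& w : victims) {
    if (auto c = w.lock()) {
      c->Close();  // only connection 1 is still alive
    }
  }
}

In the patch, ref.Get() plays the role of w.lock(): as the loop in ConnectionsWatcherFb shows, it
yields a usable pointer only if the connection still exists.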