diff --git a/src/engine/Server.cpp b/src/engine/Server.cpp index f641072d1a..95fce8a570 100644 --- a/src/engine/Server.cpp +++ b/src/engine/Server.cpp @@ -320,7 +320,11 @@ Awaitable Server::process( } else if (auto cmd = checkParameter("cmd", "dump-active-queries", accessTokenOk)) { logCommand(cmd, "dump active queries"); - response = createJsonResponse(queryRegistry_.getActiveQueries(), request); + nlohmann::json json; + for (auto& [key, value] : queryRegistry_.getActiveQueries()) { + json[nlohmann::json(key)] = std::move(value); + } + response = createJsonResponse(json, request); } // Ping with or without messsage. @@ -488,13 +492,15 @@ class QueryAlreadyInUseError : public std::runtime_error { // _____________________________________________ ad_utility::websocket::OwningQueryId Server::getQueryId( - const ad_utility::httpUtils::HttpRequest auto& request) { + const ad_utility::httpUtils::HttpRequest auto& request, + std::string_view query) { using ad_utility::websocket::OwningQueryId; std::string_view queryIdHeader = request.base()["Query-Id"]; if (queryIdHeader.empty()) { - return queryRegistry_.uniqueId(); + return queryRegistry_.uniqueId(query); } - auto queryId = queryRegistry_.uniqueIdFromString(std::string(queryIdHeader)); + auto queryId = + queryRegistry_.uniqueIdFromString(std::string(queryIdHeader), query); if (!queryId) { throw QueryAlreadyInUseError{queryIdHeader}; } @@ -637,7 +643,7 @@ boost::asio::awaitable Server::processQuery( auto queryHub = queryHub_.lock(); AD_CORRECTNESS_CHECK(queryHub); auto messageSender = co_await ad_utility::websocket::MessageSender::create( - getQueryId(request), *queryHub); + getQueryId(request, query), *queryHub); // Do the query planning. This creates a `QueryExecutionTree`, which will // then be used to process the query. // diff --git a/src/engine/Server.h b/src/engine/Server.h index 8356138689..b559e8d7d3 100644 --- a/src/engine/Server.h +++ b/src/engine/Server.h @@ -156,11 +156,13 @@ class Server { /// `QueryAlreadyInUseError` exception is thrown. /// /// \param request The HTTP request to extract the id from. + /// \param query A string representation of the query to register an id for. /// /// \return An OwningQueryId object. It removes itself from the registry /// on destruction. ad_utility::websocket::OwningQueryId getQueryId( - const ad_utility::httpUtils::HttpRequest auto& request); + const ad_utility::httpUtils::HttpRequest auto& request, + std::string_view query); /// Schedule a task to trigger the timeout after the `timeLimit`. /// The returned callback can be used to prevent this task from executing diff --git a/src/util/http/websocket/QueryId.h b/src/util/http/websocket/QueryId.h index d6ce5ca015..65cf2241d1 100644 --- a/src/util/http/websocket/QueryId.h +++ b/src/util/http/websocket/QueryId.h @@ -70,8 +70,15 @@ static_assert(!std::is_copy_assignable_v); /// A factory class to create unique query ids within each individual instance. class QueryRegistry { + struct CancellationHandleWithQuery { + SharedCancellationHandle cancellationHandle_ = + std::make_shared>(); + std::string query_; + explicit CancellationHandleWithQuery(std::string_view query) + : query_{query} {} + }; using SynchronizedType = ad_utility::Synchronized< - ad_utility::HashMap>; + ad_utility::HashMap>; // Technically no shared pointer is required because the registry lives // for the entire lifetime of the application, but since the instances of // `OwningQueryId` need to deregister themselves again they need some @@ -84,15 +91,14 @@ class QueryRegistry { /// Tries to create a new unique OwningQueryId object from the given string. /// \param id The id representation of the potential candidate. + /// \param query The string representation of the associated SPARQL query. /// \return A std::optional object wrapping the passed string /// if it was not present in the registry before. An empty /// std::optional if the id already existed before. - std::optional uniqueIdFromString(std::string id) { + std::optional uniqueIdFromString(std::string id, + std::string_view query) { auto queryId = QueryId::idFromString(std::move(id)); - bool success = - registry_->wlock() - ->emplace(queryId, std::make_shared>()) - .second; + bool success = registry_->wlock()->try_emplace(queryId, query).second; if (success) { // Avoid undefined behavior when the registry is no longer alive at the // time the `OwningQueryId` is destroyed. @@ -111,25 +117,25 @@ class QueryRegistry { } /// Generates a unique pseudo-random OwningQueryId object for this registry - OwningQueryId uniqueId() { + /// and associates it with the given query. + OwningQueryId uniqueId(std::string_view query) { static thread_local std::mt19937 generator(std::random_device{}()); std::uniform_int_distribution distrib{}; std::optional result; do { - result = uniqueIdFromString(std::to_string(distrib(generator))); + result = uniqueIdFromString(std::to_string(distrib(generator)), query); } while (!result.has_value()); return std::move(result.value()); } /// Member function that acquires a read lock and returns a vector /// of all currently registered queries. - std::vector getActiveQueries() const { + ad_utility::HashMap getActiveQueries() const { return registry_->withReadLock([](const auto& map) { - // TODO Use `ranges::to` to transform map keys into vector - std::vector result; + ad_utility::HashMap result; result.reserve(map.size()); for (const auto& entry : map) { - result.emplace_back(entry.first); + result.emplace(entry.first, entry.second.query_); } return result; }); @@ -140,7 +146,7 @@ class QueryRegistry { SharedCancellationHandle getCancellationHandle(const QueryId& queryId) const { return registry_->withReadLock([&queryId](const auto& map) { auto it = map.find(queryId); - return it != map.end() ? it->second : nullptr; + return it != map.end() ? it->second.cancellationHandle_ : nullptr; }); } }; diff --git a/test/MessageSenderTest.cpp b/test/MessageSenderTest.cpp index 11984986ef..ae7e5ba2a3 100644 --- a/test/MessageSenderTest.cpp +++ b/test/MessageSenderTest.cpp @@ -25,7 +25,7 @@ using ::testing::VariantWith; ASYNC_TEST(MessageSender, destructorCallsSignalEnd) { QueryRegistry queryRegistry; - OwningQueryId queryId = queryRegistry.uniqueId(); + OwningQueryId queryId = queryRegistry.uniqueId("my-query"); QueryHub queryHub{ioContext}; auto distributor = co_await queryHub.createOrAcquireDistributorForReceiving( @@ -47,7 +47,7 @@ ASYNC_TEST(MessageSender, destructorCallsSignalEnd) { ASYNC_TEST(MessageSender, callingOperatorBroadcastsPayload) { QueryRegistry queryRegistry; - OwningQueryId queryId = queryRegistry.uniqueId(); + OwningQueryId queryId = queryRegistry.uniqueId("my-query"); QueryHub queryHub{ioContext}; { @@ -85,7 +85,7 @@ ASYNC_TEST(MessageSender, callingOperatorBroadcastsPayload) { ASYNC_TEST(MessageSender, testGetQueryIdGetterWorks) { QueryRegistry queryRegistry; - OwningQueryId queryId = queryRegistry.uniqueId(); + OwningQueryId queryId = queryRegistry.uniqueId("my-query"); QueryId reference = queryId.toQueryId(); QueryHub queryHub{ioContext}; diff --git a/test/QueryIdTest.cpp b/test/QueryIdTest.cpp index 7ea1d1a20a..b8118bdb92 100644 --- a/test/QueryIdTest.cpp +++ b/test/QueryIdTest.cpp @@ -9,9 +9,8 @@ using ad_utility::websocket::OwningQueryId; using ad_utility::websocket::QueryId; using ad_utility::websocket::QueryRegistry; -using ::testing::ElementsAre; +using ::testing::ContainerEq; using ::testing::IsEmpty; -using ::testing::UnorderedElementsAre; TEST(QueryId, checkIdEqualityRelation) { auto queryIdOne = QueryId::idFromString("some-id"); @@ -53,8 +52,8 @@ TEST(QueryId, veriyToJsonWorks) { TEST(QueryRegistry, verifyUniqueIdProvidesUniqueIds) { QueryRegistry registry{}; - auto queryIdOne = registry.uniqueId(); - auto queryIdTwo = registry.uniqueId(); + auto queryIdOne = registry.uniqueId("my-query"); + auto queryIdTwo = registry.uniqueId("my-query"); EXPECT_NE(queryIdOne.toQueryId(), queryIdTwo.toQueryId()); } @@ -63,8 +62,10 @@ TEST(QueryRegistry, verifyUniqueIdProvidesUniqueIds) { TEST(QueryRegistry, verifyUniqueIdFromStringEnforcesUniqueness) { QueryRegistry registry{}; - auto optionalQueryIdOne = registry.uniqueIdFromString("01123581321345589144"); - auto optionalQueryIdTwo = registry.uniqueIdFromString("01123581321345589144"); + auto optionalQueryIdOne = + registry.uniqueIdFromString("01123581321345589144", "my-query"); + auto optionalQueryIdTwo = + registry.uniqueIdFromString("01123581321345589144", "my-query"); EXPECT_TRUE(optionalQueryIdOne.has_value()); EXPECT_FALSE(optionalQueryIdTwo.has_value()); @@ -75,11 +76,13 @@ TEST(QueryRegistry, verifyUniqueIdFromStringEnforcesUniqueness) { TEST(QueryRegistry, verifyIdIsUnregisteredAfterUse) { QueryRegistry registry{}; { - auto optionalQueryId = registry.uniqueIdFromString("01123581321345589144"); + auto optionalQueryId = + registry.uniqueIdFromString("01123581321345589144", "my-query"); EXPECT_TRUE(optionalQueryId.has_value()); } { - auto optionalQueryId = registry.uniqueIdFromString("01123581321345589144"); + auto optionalQueryId = + registry.uniqueIdFromString("01123581321345589144", "my-query"); EXPECT_TRUE(optionalQueryId.has_value()); } } @@ -89,8 +92,10 @@ TEST(QueryRegistry, verifyIdIsUnregisteredAfterUse) { TEST(QueryRegistry, demonstrateRegistryLocalUniqueness) { QueryRegistry registryOne{}; QueryRegistry registryTwo{}; - auto optQidOne = registryOne.uniqueIdFromString("01123581321345589144"); - auto optQidTwo = registryTwo.uniqueIdFromString("01123581321345589144"); + auto optQidOne = + registryOne.uniqueIdFromString("01123581321345589144", "my-query"); + auto optQidTwo = + registryTwo.uniqueIdFromString("01123581321345589144", "my-query"); ASSERT_TRUE(optQidOne.has_value()); ASSERT_TRUE(optQidTwo.has_value()); // The QueryId object doesn't know anything about registries, @@ -106,7 +111,7 @@ TEST(QueryRegistry, performCleanupFromDestroyedRegistry) { std::unique_ptr holder; { QueryRegistry registry{}; - holder = std::make_unique(registry.uniqueId()); + holder = std::make_unique(registry.uniqueId("my-query")); } } @@ -114,7 +119,7 @@ TEST(QueryRegistry, performCleanupFromDestroyedRegistry) { TEST(QueryRegistry, verifyCancellationHandleIsCreated) { QueryRegistry registry{}; - auto queryId = registry.uniqueId(); + auto queryId = registry.uniqueId("my-query"); auto handle1 = registry.getCancellationHandle(queryId.toQueryId()); auto handle2 = registry.getCancellationHandle(queryId.toQueryId()); @@ -138,24 +143,27 @@ TEST(QueryRegistry, verifyCancellationHandleIsNullptrIfNotPresent) { // _____________________________________________________________________________ TEST(QueryRegistry, verifyGetActiveQueriesReturnsAllActiveQueries) { + using MapType = ad_utility::HashMap; QueryRegistry registry{}; EXPECT_THAT(registry.getActiveQueries(), IsEmpty()); { - auto queryId1 = registry.uniqueId(); + auto queryId1 = registry.uniqueId("my-query"); - EXPECT_THAT(registry.getActiveQueries(), ElementsAre(queryId1.toQueryId())); + EXPECT_THAT(registry.getActiveQueries(), + ContainerEq(MapType{{queryId1.toQueryId(), "my-query"}})); { - auto queryId2 = registry.uniqueId(); + auto queryId2 = registry.uniqueId("other-query"); - EXPECT_THAT( - registry.getActiveQueries(), - UnorderedElementsAre(queryId1.toQueryId(), queryId2.toQueryId())); + EXPECT_THAT(registry.getActiveQueries(), + ContainerEq(MapType{{queryId1.toQueryId(), "my-query"}, + {queryId2.toQueryId(), "other-query"}})); } - EXPECT_THAT(registry.getActiveQueries(), ElementsAre(queryId1.toQueryId())); + EXPECT_THAT(registry.getActiveQueries(), + ContainerEq(MapType{{queryId1.toQueryId(), "my-query"}})); } EXPECT_THAT(registry.getActiveQueries(), IsEmpty()); diff --git a/test/WebSocketSessionTest.cpp b/test/WebSocketSessionTest.cpp index 898f1facd3..c01dff3f6a 100644 --- a/test/WebSocketSessionTest.cpp +++ b/test/WebSocketSessionTest.cpp @@ -180,7 +180,7 @@ ASYNC_TEST(WebSocketSession, verifySessionEndsWhenServerIsDoneSending) { ASYNC_TEST(WebSocketSession, verifyCancelStringTriggersCancellation) { auto c = co_await createTestContainer(ioContext); - auto queryId = c.registry_.uniqueIdFromString("some-id"); + auto queryId = c.registry_.uniqueIdFromString("some-id", "my-query"); ASSERT_TRUE(queryId.has_value()); auto cancellationHandle = c.registry_.getCancellationHandle(queryId->toQueryId()); @@ -285,7 +285,7 @@ ASYNC_TEST(WebSocketSession, verifyWrongExecutorConfigThrows) { ASYNC_TEST(WebSocketSession, verifyCancelOnCloseStringTriggersCancellation) { auto c = co_await createTestContainer(ioContext); - auto queryId = c.registry_.uniqueIdFromString("some-id"); + auto queryId = c.registry_.uniqueIdFromString("some-id", "my-query"); ASSERT_TRUE(queryId.has_value()); auto cancellationHandle = c.registry_.getCancellationHandle(queryId->toQueryId()); @@ -353,7 +353,7 @@ ASYNC_TEST(WebSocketSession, verifyCancelOnCloseStringTriggersCancellation) { ASYNC_TEST(WebSocketSession, verifyWithoutClientActionNoCancelDoesHappen) { auto c = co_await createTestContainer(ioContext); - auto queryId = c.registry_.uniqueIdFromString("some-id"); + auto queryId = c.registry_.uniqueIdFromString("some-id", "my-query"); ASSERT_TRUE(queryId.has_value()); auto cancellationHandle = c.registry_.getCancellationHandle(queryId->toQueryId());