diff --git a/common/zmqclient.cpp b/common/zmqclient.cpp index 5a84160e9..748d36e7b 100644 --- a/common/zmqclient.cpp +++ b/common/zmqclient.cpp @@ -9,6 +9,7 @@ #include #include #include "zmqclient.h" +#include "zmqserver.h" #include "binaryserializer.h" using namespace std; @@ -114,26 +115,9 @@ void ZmqClient::connect() m_connected = true; } -void ZmqClient::sendMsg( - const std::string& dbName, - const std::string& tableName, - const std::vector& kcos) +void ZmqClient::sendRaw(const char* buffer, size_t size) { - int serializedlen = (int)BinarySerializer::serializeBuffer( - m_sendbuffer.data(), - m_sendbuffer.size(), - dbName, - tableName, - kcos); - - if (serializedlen >= MQ_RESPONSE_MAX_COUNT) - { - SWSS_LOG_THROW("ZmqClient sendMsg message was too big (buffer size %d bytes, got %d), reduce the message size, message DROPPED", - MQ_RESPONSE_MAX_COUNT, - serializedlen); - } - - SWSS_LOG_DEBUG("sending: %d", serializedlen); + SWSS_LOG_DEBUG("sending: %zu", size); int zmq_err = 0; int retry_delay = 10; int rc = 0; @@ -144,12 +128,12 @@ void ZmqClient::sendMsg( std::lock_guard lock(m_socketMutex); // Use none block mode to use all bandwidth: http://api.zeromq.org/2-1%3Azmq-send - rc = zmq_send(m_socket, m_sendbuffer.data(), serializedlen, ZMQ_NOBLOCK); + rc = zmq_send(m_socket, buffer, size, ZMQ_NOBLOCK); } if (rc >= 0) { - SWSS_LOG_DEBUG("zmq sended %d bytes", serializedlen); + SWSS_LOG_DEBUG("zmq sended %zu bytes", size); return; } @@ -192,9 +176,31 @@ void ZmqClient::sendMsg( } // failed after retry - auto message = "zmq send failed, endpoint: " + m_endpoint + ", zmqerrno: " + to_string(zmq_err) + ":" + zmq_strerror(zmq_err) + ", msg length:" + to_string(serializedlen); + auto message = "zmq send failed, endpoint: " + m_endpoint + ", zmqerrno: " + to_string(zmq_err) + ":" + zmq_strerror(zmq_err) + ", msg length:" + to_string(size); SWSS_LOG_ERROR("%s", message.c_str()); throw system_error(make_error_code(errc::io_error), message); } +void ZmqClient::sendMsg( + const std::string& dbName, + const std::string& tableName, + const std::vector& kcos) +{ + int serializedlen = (int)BinarySerializer::serializeBuffer( + m_sendbuffer.data(), + m_sendbuffer.size(), + dbName, + tableName, + kcos); + + if (serializedlen >= MQ_RESPONSE_MAX_COUNT) + { + SWSS_LOG_THROW("ZmqClient sendMsg message was too big (buffer size %d bytes, got %d), reduce the message size, message DROPPED", + MQ_RESPONSE_MAX_COUNT, + serializedlen); + } + + sendRaw(m_sendbuffer.data(), serializedlen); +} + } diff --git a/common/zmqclient.h b/common/zmqclient.h index adc36b053..1d31523a5 100644 --- a/common/zmqclient.h +++ b/common/zmqclient.h @@ -5,7 +5,7 @@ #include #include #include -#include "zmqserver.h" +#include "table.h" namespace swss { @@ -23,6 +23,8 @@ class ZmqClient void sendMsg(const std::string& dbName, const std::string& tableName, const std::vector& kcos); + + void sendRaw(const char* buffer, size_t size); private: void initialize(const std::string& endpoint, const std::string& vrf); diff --git a/common/zmqconsumerstatetable.cpp b/common/zmqconsumerstatetable.cpp index 5f58482f9..3472d05dc 100644 --- a/common/zmqconsumerstatetable.cpp +++ b/common/zmqconsumerstatetable.cpp @@ -39,6 +39,11 @@ ZmqConsumerStateTable::ZmqConsumerStateTable(DBConnector *db, const std::string SWSS_LOG_DEBUG("ZmqConsumerStateTable ctor tableName: %s", tableName.c_str()); } +ZmqConsumerStateTable::~ZmqConsumerStateTable() +{ + m_zmqServer.unregisterMessageHandler(m_db->getDbName(), getTableName()); +} + void ZmqConsumerStateTable::handleReceivedData(const std::vector> &kcos) { for (auto kco : kcos) diff --git a/common/zmqconsumerstatetable.h b/common/zmqconsumerstatetable.h index dece60bd7..ddcc6403f 100644 --- a/common/zmqconsumerstatetable.h +++ b/common/zmqconsumerstatetable.h @@ -19,6 +19,7 @@ class ZmqConsumerStateTable : public Selectable, public TableBase, public ZmqMes static constexpr int DEFAULT_POP_BATCH_SIZE = 128; ZmqConsumerStateTable(DBConnector *db, const std::string &tableName, ZmqServer &zmqServer, int popBatchSize = DEFAULT_POP_BATCH_SIZE, int pri = 0, bool dbPersistence = false); + ~ZmqConsumerStateTable(); /* Get multiple pop elements */ void pops(std::deque &vkco, const std::string &prefix = EMPTY_PREFIX); diff --git a/common/zmqserver.cpp b/common/zmqserver.cpp index dca107405..02b909649 100644 --- a/common/zmqserver.cpp +++ b/common/zmqserver.cpp @@ -18,7 +18,8 @@ ZmqServer::ZmqServer(const std::string& endpoint) ZmqServer::ZmqServer(const std::string& endpoint, const std::string& vrf) : m_endpoint(endpoint), - m_vrf(vrf) + m_vrf(vrf), + m_proxy_mode(false) { m_buffer.resize(MQ_RESPONSE_MAX_COUNT); m_runThread = true; @@ -33,11 +34,23 @@ ZmqServer::~ZmqServer() m_mqPollThread->join(); } +void ZmqServer::enableProxyMode(const std::string& proxy_endpoint) +{ + m_proxy_client = make_unique(proxy_endpoint); + m_proxy_mode = true; +} + +bool ZmqServer::isProxyMode() const { + return m_proxy_mode; +} + void ZmqServer::registerMessageHandler( const std::string dbName, const std::string tableName, ZmqMessageHandler* handler) { + std::lock_guard lock(m_handlerMapMutext); + auto dbResult = m_HandlerMap.insert(pair>(dbName, map())); if (dbResult.second) { SWSS_LOG_DEBUG("ZmqServer add handler mapping for db: %s", dbName.c_str()); @@ -49,10 +62,31 @@ void ZmqServer::registerMessageHandler( } } +void ZmqServer::unregisterMessageHandler(const std::string &dbName, const std::string &tableName) +{ + std::lock_guard lock(m_handlerMapMutext); + + SWSS_LOG_DEBUG("ZmqServer unregister handler for db: %s, table: %s", dbName.c_str(), tableName.c_str()); + + auto db = m_HandlerMap.find(dbName); + if (db == m_HandlerMap.end()) { + SWSS_LOG_ERROR("ZmqServer can't unregister a handler for db: %s - not found", dbName.c_str()); + return; + } + + auto removed = db->second.erase(tableName); + if (!removed) { + SWSS_LOG_ERROR("ZmqServer can't unregister a handler for db: %s table %s - not found", dbName.c_str(), tableName.c_str()); + return; + } +} + ZmqMessageHandler* ZmqServer::findMessageHandler( const std::string dbName, const std::string tableName) { + std::lock_guard lock(m_handlerMapMutext); + auto dbMappingIter = m_HandlerMap.find(dbName); if (dbMappingIter == m_HandlerMap.end()) { SWSS_LOG_DEBUG("ZmqServer can't find any handler for db: %s", dbName.c_str()); @@ -77,12 +111,15 @@ void ZmqServer::handleReceivedData(const char* buffer, const size_t size) // find handler auto handler = findMessageHandler(dbName, tableName); - if (handler == nullptr) { - SWSS_LOG_WARN("ZmqServer can't find handler for received message: %s", buffer); - return; + if (handler) { + handler->handleReceivedData(kcos); + } else { + if (isProxyMode()) { + m_proxy_client->sendRaw(buffer, size); + } else { + SWSS_LOG_WARN("ZmqServer can't find handler for received message: %.*s", (int)size, buffer); + } } - - handler->handleReceivedData(kcos); } void ZmqServer::mqPollThread() diff --git a/common/zmqserver.h b/common/zmqserver.h index 8afe18d7c..ede2bb232 100644 --- a/common/zmqserver.h +++ b/common/zmqserver.h @@ -4,7 +4,10 @@ #include #include #include +#include +#include #include "table.h" +#include "zmqclient.h" #define MQ_RESPONSE_MAX_COUNT (16*1024*1024) #define MQ_SIZE 100 @@ -34,15 +37,22 @@ class ZmqServer ZmqServer(const std::string& endpoint, const std::string& vrf); ~ZmqServer(); + void enableProxyMode(const std::string& proxy_endpoint); + void registerMessageHandler( const std::string dbName, const std::string tableName, ZmqMessageHandler* handler); + void unregisterMessageHandler(const std::string &dbName, + const std::string &tableName); + private: void handleReceivedData(const char* buffer, const size_t size); void mqPollThread(); + + bool isProxyMode() const; ZmqMessageHandler* findMessageHandler(const std::string dbName, const std::string tableName); @@ -56,6 +66,12 @@ class ZmqServer std::string m_vrf; + std::atomic m_proxy_mode; + + std::unique_ptr m_proxy_client; + + std::mutex m_handlerMapMutext; + std::map> m_HandlerMap; }; diff --git a/tests/Makefile.am b/tests/Makefile.am index 9642b09ab..37a3e1b68 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -44,6 +44,7 @@ tests_tests_SOURCES = tests/redis_ut.cpp \ tests/profileprovider_ut.cpp \ tests/c_api_ut.cpp \ tests/performancetimer_ut.cpp \ + tests/zmq_proxy_ut.cpp \ tests/main.cpp tests_tests_CFLAGS = $(DBGFLAGS) $(AM_CFLAGS) $(CFLAGS_COMMON) $(CFLAGS_GTEST) $(LIBNL_CFLAGS) diff --git a/tests/c_api_ut.cpp b/tests/c_api_ut.cpp index ed814607e..84d6da149 100644 --- a/tests/c_api_ut.cpp +++ b/tests/c_api_ut.cpp @@ -322,12 +322,12 @@ TEST(c_api, ZmqConsumerProducerStateTable) { EXPECT_EQ(kfvFieldsValues(kfvs[1]).size(), 0); } - // Server must be freed first to safely release message handlers (ZmqConsumerStateTable) - SWSSZmqServer_free(srv); - + // The message handlers (ZmqConsumerStateTable) must be freed first to safely unregister from the Server SWSSZmqProducerStateTable_free(pst); SWSSZmqConsumerStateTable_free(cst); + SWSSZmqServer_free(srv); + SWSSZmqClient_free(cli); SWSSDBConnector_flushdb(db); diff --git a/tests/zmq_proxy_ut.cpp b/tests/zmq_proxy_ut.cpp new file mode 100644 index 000000000..1535bb74c --- /dev/null +++ b/tests/zmq_proxy_ut.cpp @@ -0,0 +1,140 @@ +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "common/dbconnector.h" +#include "common/select.h" +#include "common/selectableevent.h" +#include "common/table.h" +#include "common/zmqserver.h" +#include "common/zmqclient.h" +#include "common/zmqproducerstatetable.h" +#include "common/zmqconsumerstatetable.h" + +using namespace std; +using namespace swss; + +#define TEST_DB "APPL_DB" +#define TEST_TABLE_FOR_SERVER "TEST_TABLE_PROXY_FORWARD" +#define TEST_TABLE_FOR_PROXY "TEST_TABLE_PROXY_CONSUME" + +#define PROXY_ADDR "tcp://*:5001" +#define PROXY_ENDPOINT "tcp://localhost:5001" +#define SERVER_ADDR "tcp://*:5002" +#define SERVER_ENDPOINT "tcp://localhost:5002" + +#define SELECT_TIMEOUT_EXPECT_RECEIVE 10000 +#define SELECT_TIMEOUT_EXPECT_NO_DATA 1000 + +std::vector consume(ZmqConsumerStateTable &c, size_t keys_to_receive = 1) +{ + Select cs; + cs.addSelectable(&c); + + Selectable *selectcs; + std::vector received; + + while (received.size() < keys_to_receive) + { + std::deque vkco; + int ret = cs.select(&selectcs, SELECT_TIMEOUT_EXPECT_RECEIVE, true); + EXPECT_EQ(ret, Select::OBJECT); + + c.pops(vkco); + received.insert(end(received), begin(vkco), end(vkco)); + } + + EXPECT_EQ(received.size(), keys_to_receive) << "Unexpected number of keys received"; + + // Verify that all data are read + int ret = cs.select(&selectcs, SELECT_TIMEOUT_EXPECT_NO_DATA); + EXPECT_EQ(ret, Select::TIMEOUT) << "Unexpected data received in the consumer"; + + return received; +} + +void produce(ZmqClient &client, const string& table, const std::vector& vkco) +{ + DBConnector db(TEST_DB, 0, true); + ZmqProducerStateTable p(&db, table, client, false); + p.set(vkco); +} + +std::vector generate_vkco(size_t keys_count = 1, const string& key_prefix = "") +{ + std::vector data; + + for (size_t i = 0; i < keys_count; i++) + { + data.emplace_back(KeyOpFieldsValuesTuple{key_prefix + "test_key_" + to_string(i), "SET", std::vector { + FieldValueTuple("test_field0", "test_value0"), + FieldValueTuple("test_field1", "test_value1") + }}); + } + + return data; +} + +void validate_vkco(const std::vector& vkco, size_t expected_keys_count = 1, const string& key_prefix = "") +{ + ASSERT_EQ(vkco.size(), expected_keys_count); + + for (size_t i = 0; i < expected_keys_count; i++) + { + ASSERT_EQ(kfvKey(vkco[i]), key_prefix + "test_key_" + to_string(i)); + } +} + +TEST(ZmqProxy, proxy_forward) +{ + const size_t key_to_test = 5; + std::vector data = generate_vkco(key_to_test); + DBConnector db(TEST_DB, 0, true); + + // Setup proxy server w/o any consumers to forward all + ZmqServer proxyServer(PROXY_ADDR); + proxyServer.enableProxyMode(SERVER_ENDPOINT); + ASSERT_TRUE(proxyServer.isProxyMode()); + + ZmqServer server(SERVER_ADDR); + ASSERT_FALSE(server.isProxyMode()); + ZmqConsumerStateTable serverConsumer(&db, TEST_TABLE_FOR_SERVER, server); + + ZmqClient client(PROXY_ENDPOINT); + produce(client, TEST_TABLE_FOR_SERVER, data); + + auto recevied = consume(serverConsumer, key_to_test); + validate_vkco(recevied, key_to_test); +} + +TEST(ZmqProxy, proxy_forward_consume) +{ + const size_t key_to_test = 5; + std::vector foward_data = generate_vkco(key_to_test, "forward"); + std::vector consume_data = generate_vkco(key_to_test, "consume"); + DBConnector db(TEST_DB, 0, true); + + // Setup proxy server with a consumer for TEST_TABLE_FOR_PROXY table + ZmqServer proxyServer(PROXY_ADDR); + proxyServer.enableProxyMode(SERVER_ENDPOINT); + ZmqConsumerStateTable proxyConsumer(&db, TEST_TABLE_FOR_PROXY, proxyServer); + + ZmqServer server(SERVER_ADDR); + ZmqConsumerStateTable serverConsumer(&db, TEST_TABLE_FOR_SERVER, server); + // This should not receive any data since the table is consumed in the proxy + ZmqConsumerStateTable serverConsumerUnexpectedTable(&db, TEST_TABLE_FOR_PROXY, server); + + ZmqClient client(PROXY_ENDPOINT); + produce(client, TEST_TABLE_FOR_SERVER, foward_data); + produce(client, TEST_TABLE_FOR_PROXY, consume_data); + + auto recevied = consume(serverConsumer, key_to_test); + validate_vkco(recevied, key_to_test, "forward"); + + consume(serverConsumerUnexpectedTable, 0); + + recevied = consume(proxyConsumer, key_to_test); + validate_vkco(recevied, key_to_test, "consume"); +}