From fcea90036a6e2c1324ea56e92285903cb5384f7a Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Mon, 25 Nov 2024 01:10:45 -0500 Subject: [PATCH 1/8] Add dynamic websockets support Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- cpp/csp/adapters/websocket/CMakeLists.txt | 6 + .../websocket/ClientAdapterManager.cpp | 133 ++--- .../adapters/websocket/ClientAdapterManager.h | 38 +- .../ClientConnectionRequestAdapter.cpp | 48 ++ .../ClientConnectionRequestAdapter.h | 45 ++ .../websocket/ClientHeaderUpdateAdapter.cpp | 20 +- .../websocket/ClientHeaderUpdateAdapter.h | 11 +- .../adapters/websocket/ClientInputAdapter.cpp | 54 +- .../adapters/websocket/ClientInputAdapter.h | 7 +- .../websocket/ClientOutputAdapter.cpp | 17 +- .../adapters/websocket/ClientOutputAdapter.h | 15 +- .../websocket/WebsocketClientTypes.cpp | 13 + .../adapters/websocket/WebsocketClientTypes.h | 26 + .../adapters/websocket/WebsocketEndpoint.cpp | 49 +- .../adapters/websocket/WebsocketEndpoint.h | 211 ++++---- .../websocket/WebsocketEndpointManager.cpp | 432 ++++++++++++++++ .../websocket/WebsocketEndpointManager.h | 152 ++++++ cpp/csp/python/Conversions.h | 24 + .../python/adapters/websocketadapterimpl.cpp | 37 +- csp/adapters/dynamic_adapter_utils.py | 6 + csp/adapters/websocket.py | 157 +++++- csp/adapters/websocket_types.py | 19 + csp/tests/adapters/test_websocket.py | 481 ++++++++++++++++-- 23 files changed, 1695 insertions(+), 306 deletions(-) create mode 100644 cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp create mode 100644 cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h create mode 100644 cpp/csp/adapters/websocket/WebsocketClientTypes.cpp create mode 100644 cpp/csp/adapters/websocket/WebsocketClientTypes.h create mode 100644 cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp create mode 100644 cpp/csp/adapters/websocket/WebsocketEndpointManager.h create mode 100644 csp/adapters/dynamic_adapter_utils.py diff --git a/cpp/csp/adapters/websocket/CMakeLists.txt b/cpp/csp/adapters/websocket/CMakeLists.txt index 402d01b69..bb85d20da 100644 --- a/cpp/csp/adapters/websocket/CMakeLists.txt +++ b/cpp/csp/adapters/websocket/CMakeLists.txt @@ -1,20 +1,26 @@ csp_autogen( csp.adapters.websocket_types websocket_types WEBSOCKET_HEADER WEBSOCKET_SOURCE ) set(WS_CLIENT_HEADER_FILES + WebsocketClientTypes.h ClientAdapterManager.h ClientInputAdapter.h ClientOutputAdapter.h ClientHeaderUpdateAdapter.h + ClientConnectionRequestAdapter.h WebsocketEndpoint.h + WebsocketEndpointManager.h ${WEBSOCKET_HEADER} ) set(WS_CLIENT_SOURCE_FILES + WebsocketClientTypes.cpp ClientAdapterManager.cpp ClientInputAdapter.cpp ClientOutputAdapter.cpp ClientHeaderUpdateAdapter.cpp + ClientConnectionRequestAdapter.cpp WebsocketEndpoint.cpp + WebsocketEndpointManager.cpp ${WS_CLIENT_HEADER_FILES} ${WEBSOCKET_SOURCE} ) diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp index 423f2a234..ab609462b 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp @@ -1,124 +1,59 @@ #include <csp/adapters/websocket/ClientAdapterManager.h> -namespace csp { - -INIT_CSP_ENUM( adapters::websocket::ClientStatusType, - "ACTIVE", - "GENERIC_ERROR", - "CONNECTION_FAILED", - "CLOSED", - "MESSAGE_SEND_FAIL", -); - -} - -// With TLS namespace csp::adapters::websocket { ClientAdapterManager::ClientAdapterManager( Engine* engine, const Dictionary & properties ) : AdapterManager( engine ), - m_active( false ), - m_shouldRun( false ), - m_endpoint( std::make_unique<WebsocketEndpoint>( properties ) ), - m_inputAdapter( nullptr ), - m_outputAdapter( nullptr ), - m_updateAdapter( nullptr ), - m_thread( nullptr ), - m_properties( properties ) -{ }; + m_properties( properties ) +{ } ClientAdapterManager::~ClientAdapterManager() -{ }; +{ } -void ClientAdapterManager::start( DateTime starttime, DateTime endtime ) -{ - AdapterManager::start( starttime, endtime ); - - m_shouldRun = true; - m_endpoint -> setOnOpen( - [ this ]() { - m_active = true; - pushStatus( StatusLevel::INFO, ClientStatusType::ACTIVE, "Connected successfully" ); - } - ); - m_endpoint -> setOnFail( - [ this ]( const std::string& reason ) { - std::stringstream ss; - ss << "Connection Failure: " << reason; - m_active = false; - pushStatus( StatusLevel::ERROR, ClientStatusType::CONNECTION_FAILED, ss.str() ); - } - ); - if( m_inputAdapter ) { - m_endpoint -> setOnMessage( - [ this ]( void* c, size_t t ) { - PushBatch batch( m_engine -> rootEngine() ); - m_inputAdapter -> processMessage( c, t, &batch ); - } - ); - } else { - // if a user doesn't call WebsocketAdapterManager.subscribe, no inputadapter will be created - // but we still need something to avoid on_message_cb not being set in the endpoint. - m_endpoint -> setOnMessage( []( void* c, size_t t ){} ); - } - m_endpoint -> setOnClose( - [ this ]() { - m_active = false; - pushStatus( StatusLevel::INFO, ClientStatusType::CLOSED, "Connection closed" ); - } - ); - m_endpoint -> setOnSendFail( - [ this ]( const std::string& s ) { - std::stringstream ss; - ss << "Failed to send: " << s; - pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, ss.str() ); - } - ); +WebsocketEndpointManager* ClientAdapterManager::getWebsocketManager(){ + if( m_endpointManager == nullptr ) + return nullptr; + return m_endpointManager.get(); +} - m_thread = std::make_unique<std::thread>( [ this ]() { - while( m_shouldRun ) - { - m_endpoint -> run(); - m_active = false; - if( m_shouldRun ) sleep( m_properties.get<TimeDelta>( "reconnect_interval" ) ); - } - }); -}; +void ClientAdapterManager::start(DateTime starttime, DateTime endtime) { + AdapterManager::start(starttime, endtime); + if (m_endpointManager != nullptr) + m_endpointManager -> start(starttime, endtime); +} void ClientAdapterManager::stop() { AdapterManager::stop(); - - m_shouldRun=false; - if( m_active ) m_endpoint->stop(); - if( m_thread ) m_thread->join(); -}; + if (m_endpointManager != nullptr) + m_endpointManager -> stop(); +} PushInputAdapter* ClientAdapterManager::getInputAdapter(CspTypePtr & type, PushMode pushMode, const Dictionary & properties) -{ - if (m_inputAdapter == nullptr) - { - m_inputAdapter = m_engine -> createOwnedObject<ClientInputAdapter>( - // m_engine, - type, - pushMode, - properties - ); - } - return m_inputAdapter; -}; +{ + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique<WebsocketEndpointManager>(this, m_properties, m_engine); + return m_endpointManager -> getInputAdapter( type, pushMode, properties ); +} -OutputAdapter* ClientAdapterManager::getOutputAdapter() +OutputAdapter* ClientAdapterManager::getOutputAdapter( const Dictionary & properties ) { - if (m_outputAdapter == nullptr) m_outputAdapter = m_engine -> createOwnedObject<ClientOutputAdapter>(*m_endpoint); - - return m_outputAdapter; + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique<WebsocketEndpointManager>(this, m_properties, m_engine); + return m_endpointManager -> getOutputAdapter( properties ); } OutputAdapter * ClientAdapterManager::getHeaderUpdateAdapter() { - if (m_updateAdapter == nullptr) m_updateAdapter = m_engine -> createOwnedObject<ClientHeaderUpdateOutputAdapter>( m_endpoint -> getProperties() ); + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique<WebsocketEndpointManager>(this, m_properties, m_engine); + return m_endpointManager -> getHeaderUpdateAdapter(); +} - return m_updateAdapter; +OutputAdapter * ClientAdapterManager::getConnectionRequestAdapter( const Dictionary & properties ) +{ + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique<WebsocketEndpointManager>(this, m_properties, m_engine); + return m_endpointManager -> getConnectionRequestAdapter( properties ); } DateTime ClientAdapterManager::processNextSimTimeSlice( DateTime time ) diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.h b/cpp/csp/adapters/websocket/ClientAdapterManager.h index 62577d769..e530ee10f 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.h +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.h @@ -2,8 +2,9 @@ #define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_ADAPTERMGR_H #include <csp/adapters/websocket/WebsocketEndpoint.h> +#include <csp/adapters/websocket/WebsocketEndpointManager.h> +#include <csp/adapters/websocket/WebsocketClientTypes.h> #include <csp/adapters/websocket/ClientInputAdapter.h> -#include <csp/adapters/websocket/ClientOutputAdapter.h> #include <csp/adapters/websocket/ClientHeaderUpdateAdapter.h> #include <csp/core/Enum.h> #include <csp/core/Hash.h> @@ -15,30 +16,15 @@ #include <chrono> #include <iomanip> #include <iostream> +#include <vector> +#include <unordered_set> namespace csp::adapters::websocket { using namespace csp; -struct WebsocketClientStatusTypeTraits -{ - enum _enum : unsigned char - { - ACTIVE = 0, - GENERIC_ERROR = 1, - CONNECTION_FAILED = 2, - CLOSED = 3, - MESSAGE_SEND_FAIL = 4, - - NUM_TYPES - }; - -protected: - _enum m_value; -}; - -using ClientStatusType = Enum<WebsocketClientStatusTypeTraits>; +class WebsocketEndpointManager; class ClientAdapterManager final : public AdapterManager { @@ -57,23 +43,17 @@ class ClientAdapterManager final : public AdapterManager void stop() override; + WebsocketEndpointManager* getWebsocketManager(); PushInputAdapter * getInputAdapter( CspTypePtr & type, PushMode pushMode, const Dictionary & properties ); - OutputAdapter * getOutputAdapter(); + OutputAdapter * getOutputAdapter( const Dictionary & properties ); OutputAdapter * getHeaderUpdateAdapter(); + OutputAdapter * getConnectionRequestAdapter( const Dictionary & properties ); DateTime processNextSimTimeSlice( DateTime time ) override; private: - // need some client info - - bool m_active; - bool m_shouldRun; - std::unique_ptr<WebsocketEndpoint> m_endpoint; - ClientInputAdapter* m_inputAdapter; - ClientOutputAdapter* m_outputAdapter; - ClientHeaderUpdateOutputAdapter* m_updateAdapter; - std::unique_ptr<std::thread> m_thread; Dictionary m_properties; + std::unique_ptr<WebsocketEndpointManager> m_endpointManager; }; } diff --git a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp new file mode 100644 index 000000000..98b76c4da --- /dev/null +++ b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp @@ -0,0 +1,48 @@ +#include <csp/adapters/websocket/ClientConnectionRequestAdapter.h> +#include <csp/python/Conversions.h> +#include <Python.h> + +namespace csp::adapters::websocket { + +ClientConnectionRequestAdapter::ClientConnectionRequestAdapter( + Engine * engine, + WebsocketEndpointManager * websocketManager, + bool is_subscribe, + size_t caller_id, + boost::asio::strand<boost::asio::io_context::executor_type>& strand + +) : OutputAdapter( engine ), + m_websocketManager( websocketManager ), + m_strand( strand ), + m_isSubscribe( is_subscribe ), + m_callerId( caller_id ), + m_checkPerformed( is_subscribe ? false : true ) // we only need to check for pruned input adapters +{} + +void ClientConnectionRequestAdapter::executeImpl() +{ + // One-time check for pruned status + if (unlikely(!m_checkPerformed)) { + m_isPruned = m_websocketManager->adapterPruned(m_callerId); + m_checkPerformed = true; + } + + // Early return if pruned + if (unlikely(m_isPruned)) + return; + + auto raw_val = input()->lastValueTyped<PyObject*>(); + auto val = python::fromPython<std::vector<Dictionary>>(raw_val); + + // We intentionally post here, we want the thread running + // the strand to handle the connection request. We want to keep + // all updates to internal data structures at graph run-time + // to that thread. + boost::asio::post(m_strand, [this, val=std::move(val)]() { + for(const auto& conn_req: val) { + m_websocketManager->handleConnectionRequest(conn_req, m_callerId, m_isSubscribe); + } + }); +}; + +} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h new file mode 100644 index 000000000..505ea1164 --- /dev/null +++ b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h @@ -0,0 +1,45 @@ +#ifndef _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_CONNECTIONREQUESTADAPTER_H +#define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_CONNECTIONREQUESTADAPTER_H + +#include <csp/adapters/websocket/ClientAdapterManager.h> +#include <csp/engine/Dictionary.h> +#include <csp/engine/OutputAdapter.h> +#include <csp/adapters/utils/MessageWriter.h> +#include <csp/adapters/websocket/csp_autogen/websocket_types.h> + +namespace csp::adapters::websocket +{ +using namespace csp::autogen; + +class ClientAdapterManager; +class WebsocketEndpointManager; + +class ClientConnectionRequestAdapter final: public OutputAdapter +{ +public: + ClientConnectionRequestAdapter( + Engine * engine, + WebsocketEndpointManager * websocketManager, + bool isSubscribe, + size_t callerId, + boost::asio::strand<boost::asio::io_context::executor_type>& strand + ); + + void executeImpl() override; + + const char * name() const override { return "WebsocketClientConnectionRequestAdapter"; } + +private: + WebsocketEndpointManager* m_websocketManager; + boost::asio::strand<boost::asio::io_context::executor_type>& m_strand; + bool m_isSubscribe; + size_t m_callerId; + bool m_checkPerformed; + bool m_isPruned{false}; + +}; + +} + + +#endif \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp index 995bd7314..c25a368b7 100644 --- a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp @@ -2,19 +2,27 @@ namespace csp::adapters::websocket { +class WebsocketEndpointManager; + ClientHeaderUpdateOutputAdapter::ClientHeaderUpdateOutputAdapter( Engine * engine, - Dictionary& properties -) : OutputAdapter( engine ), m_properties( properties ) + WebsocketEndpointManager * mgr, + boost::asio::strand<boost::asio::io_context::executor_type>& strand +) : OutputAdapter( engine ), m_mgr( mgr ), m_strand( strand ) { }; void ClientHeaderUpdateOutputAdapter::executeImpl() { - DictionaryPtr headers = m_properties.get<DictionaryPtr>("headers"); - for( auto& update : input() -> lastValueTyped<std::vector<WebsocketHeaderUpdate::Ptr>>() ) - { - if( update -> key_isSet() && update -> value_isSet() ) headers->update( update->key(), update->value() ); + Dictionary headers; + for (auto& update : input()->lastValueTyped<std::vector<WebsocketHeaderUpdate::Ptr>>()) { + if (update->key_isSet() && update->value_isSet()) { + headers.update(update->key(), update->value()); + } } + boost::asio::post(m_strand, [this, headers=std::move(headers)]() { + auto endpoint = m_mgr -> getNonDynamicEndpoint(); + endpoint -> updateHeaders(std::move(headers)); + }); }; } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h index d2c898a1e..88d0ec439 100644 --- a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h +++ b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h @@ -1,6 +1,7 @@ #ifndef _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_HEADERUPDATEADAPTER_H #define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_HEADERUPDATEADAPTER_H +#include <csp/adapters/websocket/WebsocketEndpointManager.h> #include <csp/engine/Dictionary.h> #include <csp/engine/OutputAdapter.h> #include <csp/adapters/utils/MessageWriter.h> @@ -10,12 +11,15 @@ namespace csp::adapters::websocket { using namespace csp::autogen; +class WebsocketEndpointManager; + class ClientHeaderUpdateOutputAdapter final: public OutputAdapter { public: ClientHeaderUpdateOutputAdapter( Engine * engine, - Dictionary& properties + WebsocketEndpointManager * mgr, + boost::asio::strand<boost::asio::io_context::executor_type>& strand ); void executeImpl() override; @@ -23,7 +27,10 @@ class ClientHeaderUpdateOutputAdapter final: public OutputAdapter const char * name() const override { return "WebsocketClientHeaderUpdateAdapter"; } private: - Dictionary& m_properties; + WebsocketEndpointManager * m_mgr; + boost::asio::strand<boost::asio::io_context::executor_type>& m_strand; + + }; diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp index e4b0b7ff7..cff97f11b 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp @@ -1,5 +1,4 @@ #include <csp/adapters/websocket/ClientInputAdapter.h> - namespace csp::adapters::websocket { @@ -7,8 +6,10 @@ ClientInputAdapter::ClientInputAdapter( Engine * engine, CspTypePtr & type, PushMode pushMode, - const Dictionary & properties -) : PushInputAdapter(engine, type, pushMode) + const Dictionary & properties, + bool dynamic +) : PushInputAdapter(engine, type, pushMode), + m_dynamic( dynamic ) { if( type -> type() != CspType::Type::STRUCT && type -> type() != CspType::Type::STRING ) @@ -21,8 +22,14 @@ ClientInputAdapter::ClientInputAdapter( if( !metaFieldMap.empty() && type -> type() != CspType::Type::STRUCT ) CSP_THROW( ValueError, "meta_field_map is not supported on non-struct types" ); } + if ( m_dynamic ){ + auto& actual_type = static_cast<const CspStructType &>( *type ); + auto& nested_type = actual_type.meta()-> field( "msg" ) -> type(); - m_converter = adapters::utils::MessageStructConverterCache::instance().create( type, properties ); + m_converter = adapters::utils::MessageStructConverterCache::instance().create( nested_type, properties ); + } + else + m_converter = adapters::utils::MessageStructConverterCache::instance().create( type, properties ); }; void ClientInputAdapter::processMessage( void* c, size_t t, PushBatch* batch ) @@ -39,4 +46,43 @@ void ClientInputAdapter::processMessage( void* c, size_t t, PushBatch* batch ) } +void ClientInputAdapter::processMessage( std::tuple<std::string, void*> data, size_t t, PushBatch* batch ) +{ + // Extract the source string and data pointer from tuple + std::string source = std::get<0>(data); + void* c = std::get<1>(data); + if ( m_dynamic ){ + auto& actual_type = static_cast<const CspStructType &>( *dataType() ); + auto& nested_type = actual_type.meta()-> field( "msg" ) -> type(); + auto true_val = actual_type.meta() -> create(); + actual_type.meta()->field("uri")->setValue( true_val.get(), source ); + + if( nested_type -> type() == CspType::Type::STRUCT ) + { + auto tick = m_converter -> asStruct( c, t ); + actual_type.meta()->field("msg")->setValue( true_val.get(), std::move(tick) ); + + pushTick( std::move(true_val), batch ); + } else if ( nested_type -> type() == CspType::Type::STRING ) + { + auto msg = std::string((char const*)c, t); + actual_type.meta()->field("msg")->setValue( true_val.get(), msg ); + + pushTick( std::move(true_val), batch ); + } + + } + else{ + if( dataType() -> type() == CspType::Type::STRUCT ) + { + auto tick = m_converter -> asStruct( c, t ); + pushTick( std::move(tick), batch ); + } else if ( dataType() -> type() == CspType::Type::STRING ) + { + pushTick( std::string((char const*)c, t), batch ); + } + } + +} + } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.h b/cpp/csp/adapters/websocket/ClientInputAdapter.h index bf3cb295f..a5ceda58e 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.h +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.h @@ -16,17 +16,20 @@ class ClientInputAdapter final: public PushInputAdapter { Engine * engine, CspTypePtr & type, PushMode pushMode, - const Dictionary & properties + const Dictionary & properties, + bool dynamic ); void processMessage( void* c, size_t t, PushBatch* batch ); + void processMessage( std::tuple<std::string, void*> data, size_t t, PushBatch* batch ); private: adapters::utils::MessageStructConverterPtr m_converter; + const bool m_dynamic; }; } -#endif // _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_INPUTADAPTER_H \ No newline at end of file +#endif \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp b/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp index 3ef3c91ac..7b9bd83e1 100644 --- a/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp @@ -4,14 +4,23 @@ namespace csp::adapters::websocket { ClientOutputAdapter::ClientOutputAdapter( Engine * engine, - WebsocketEndpoint& endpoint -) : OutputAdapter( engine ), m_endpoint( endpoint ) + WebsocketEndpointManager * websocketManager, + size_t caller_id, + net::io_context& ioc, + boost::asio::strand<boost::asio::io_context::executor_type>& strand +) : OutputAdapter( engine ), + m_websocketManager( websocketManager ), + m_callerId( caller_id ), + m_ioc( ioc ), + m_strand( strand ) { }; void ClientOutputAdapter::executeImpl() { const std::string & value = input() -> lastValueTyped<std::string>(); - m_endpoint.send( value ); -}; + boost::asio::post(m_strand, [this, value=value]() { + m_websocketManager->send(value, m_callerId); + }); +} } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientOutputAdapter.h b/cpp/csp/adapters/websocket/ClientOutputAdapter.h index 905822e2f..d97bc8062 100644 --- a/cpp/csp/adapters/websocket/ClientOutputAdapter.h +++ b/cpp/csp/adapters/websocket/ClientOutputAdapter.h @@ -5,11 +5,13 @@ #include <csp/engine/Dictionary.h> #include <csp/engine/OutputAdapter.h> #include <csp/adapters/utils/MessageWriter.h> +#include <csp/adapters/websocket/ClientAdapterManager.h> namespace csp::adapters::websocket { class ClientAdapterManager; +class WebsocketEndpointManager; class ClientOutputAdapter final: public OutputAdapter { @@ -17,7 +19,11 @@ class ClientOutputAdapter final: public OutputAdapter public: ClientOutputAdapter( Engine * engine, - WebsocketEndpoint& endpoint + WebsocketEndpointManager * websocketManager, + size_t caller_id, + net::io_context& ioc, + boost::asio::strand<boost::asio::io_context::executor_type>& strand + // bool dynamic ); void executeImpl() override; @@ -25,7 +31,12 @@ class ClientOutputAdapter final: public OutputAdapter const char * name() const override { return "WebsocketClientOutputAdapter"; } private: - WebsocketEndpoint& m_endpoint; + WebsocketEndpointManager* m_websocketManager; + size_t m_callerId; + [[maybe_unused]] net::io_context& m_ioc; + boost::asio::strand<boost::asio::io_context::executor_type>& m_strand; + // bool m_dynamic; + // std::unordered_map<std::string, std::vector<bool>>& m_endpoint_consumers; }; } diff --git a/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp b/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp new file mode 100644 index 000000000..ac4492520 --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp @@ -0,0 +1,13 @@ +#include "WebsocketClientTypes.h" + +namespace csp { + +INIT_CSP_ENUM( adapters::websocket::ClientStatusType, + "ACTIVE", + "GENERIC_ERROR", + "CONNECTION_FAILED", + "CLOSED", + "MESSAGE_SEND_FAIL", +); + +} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketClientTypes.h b/cpp/csp/adapters/websocket/WebsocketClientTypes.h new file mode 100644 index 000000000..6f1d255f8 --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketClientTypes.h @@ -0,0 +1,26 @@ +#pragma once + +#include "csp/core/Enum.h" // or whatever the correct path is + +namespace csp::adapters::websocket { + +struct WebsocketClientStatusTypeTraits +{ + enum _enum : unsigned char + { + ACTIVE = 0, + GENERIC_ERROR = 1, + CONNECTION_FAILED = 2, + CLOSED = 3, + MESSAGE_SEND_FAIL = 4, + + NUM_TYPES + }; + +protected: + _enum m_value; +}; + +using ClientStatusType = Enum<WebsocketClientStatusTypeTraits>; + +} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp index 8e856c79e..6bb256488 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp @@ -3,9 +3,11 @@ namespace csp::adapters::websocket { using namespace csp; -WebsocketEndpoint::WebsocketEndpoint( +WebsocketEndpoint::WebsocketEndpoint( + net::io_context& ioc, Dictionary properties -) : m_properties(properties) +) : m_properties(std::make_shared<Dictionary>(std::move(properties))), + m_ioc(ioc) { }; void WebsocketEndpoint::setOnOpen(void_cb on_open) { m_on_open = std::move(on_open); } @@ -20,17 +22,16 @@ void WebsocketEndpoint::setOnSendFail(string_cb on_send_fail) void WebsocketEndpoint::run() { - - m_ioc.reset(); - if(m_properties.get<bool>("use_ssl")) { + // Owns this ioc object + if(m_properties->get<bool>("use_ssl")) { ssl::context ctx{ssl::context::sslv23}; ctx.set_verify_mode(ssl::context::verify_peer ); ctx.set_default_verify_paths(); - m_session = new WebsocketSessionTLS( + m_session = std::make_shared<WebsocketSessionTLS>( m_ioc, ctx, - &m_properties, + m_properties, m_on_open, m_on_fail, m_on_message, @@ -38,9 +39,9 @@ void WebsocketEndpoint::run() m_on_send_fail ); } else { - m_session = new WebsocketSessionNoTLS( + m_session = std::make_shared<WebsocketSessionNoTLS>( m_ioc, - &m_properties, + m_properties, m_on_open, m_on_fail, m_on_message, @@ -49,23 +50,39 @@ void WebsocketEndpoint::run() ); } m_session->run(); +} - m_ioc.run(); +WebsocketEndpoint::~WebsocketEndpoint() { + try { + // Call stop but explicitly pass false to prevent io_context shutdown + stop(false); + } catch (...) { + // Ignore any exceptions during cleanup + } } -void WebsocketEndpoint::stop() -{ - m_ioc.stop(); - if(m_session) m_session->stop(); +void WebsocketEndpoint::stop( bool stop_ioc ) +{ + if( m_session ) m_session->stop(); + if( stop_ioc ) m_ioc.stop(); } +void WebsocketEndpoint::updateHeaders(csp::Dictionary properties){ + DictionaryPtr headers = m_properties->get<DictionaryPtr>("headers"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + std::string key = it.key(); + auto value = it.value<std::string>(); + headers->update(key, std::move(value)); + } +} -csp::Dictionary& WebsocketEndpoint::getProperties() { +std::shared_ptr<Dictionary> WebsocketEndpoint::getProperties() { return m_properties; } void WebsocketEndpoint::send(const std::string& s) { if(m_session) m_session->send(s); } - +void WebsocketEndpoint::ping() +{ if(m_session) m_session->ping(); } } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.h b/cpp/csp/adapters/websocket/WebsocketEndpoint.h index cfca08742..eea9fed2b 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.h +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.h @@ -5,6 +5,9 @@ #include <boost/beast/ssl.hpp> #include <boost/beast/websocket.hpp> #include <boost/asio/strand.hpp> +#include <boost/asio/error.hpp> +#include <boost/system/error_code.hpp> +#include <boost/enable_shared_from_this.hpp> #include <csp/engine/Dictionary.h> #include <csp/core/Exception.h> @@ -13,6 +16,7 @@ #include <functional> #include <iostream> #include <string> +#include <memory> namespace csp::adapters::websocket { using namespace csp; @@ -23,6 +27,7 @@ namespace net = boost::asio; // from <boost/asio.hpp> namespace ssl = boost::asio::ssl; // from <boost/asio/ssl.hpp> namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp> using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp> +using error_code = boost::system::error_code; //from <boost/system/error_code.hpp> using string_cb = std::function<void(const std::string&)>; using char_cb = std::function<void(void*, size_t)>; @@ -30,7 +35,9 @@ using void_cb = std::function<void()>; class BaseWebsocketSession { public: + virtual ~BaseWebsocketSession() = default; virtual void stop() { }; + virtual void ping() { }; virtual void send( const std::string& ) { }; virtual void do_read() { }; virtual void do_write(const std::string& ) { }; @@ -38,24 +45,28 @@ class BaseWebsocketSession { }; template<class Derived> -class WebsocketSession : public BaseWebsocketSession { +class WebsocketSession : + public BaseWebsocketSession, + public std::enable_shared_from_this<Derived> +{ public: WebsocketSession( net::io_context& ioc, - Dictionary* properties, - void_cb& on_open, - string_cb& on_fail, - char_cb& on_message, - void_cb& on_close, - string_cb& on_send_fail - ) : m_resolver( net::make_strand( ioc ) ), - m_properties( properties ), - m_on_open( on_open ), - m_on_fail( on_fail ), - m_on_message( on_message ), - m_on_close( on_close ), - m_on_send_fail( on_send_fail ) - { }; + std::shared_ptr<Dictionary> properties, + void_cb on_open, + string_cb on_fail, + char_cb on_message, + void_cb on_close, + string_cb on_send_fail + ) : m_resolver(net::make_strand(ioc)), + m_properties(properties), + m_on_open(std::move(on_open)), + m_on_fail(std::move(on_fail)), + m_on_message(std::move(on_message)), + m_on_close(std::move(on_close)), + m_on_send_fail(std::move(on_send_fail)) + { } + ~WebsocketSession() override = default; Derived& derived(){ return static_cast<Derived&>(*this); } @@ -81,53 +92,66 @@ class WebsocketSession : public BaseWebsocketSession { } void do_read() override { + auto self = std::static_pointer_cast<Derived>(this->shared_from_this()); derived().ws().async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + self->m_buffer, + [ self ](beast::error_code ec, std::size_t bytes_transfered) { + self->handle_message(ec, bytes_transfered); + } ); } + void ping() override { + auto self = std::static_pointer_cast<Derived>(this->shared_from_this()); + derived().ws().async_ping({}, + [ self ](beast::error_code ec) { + if(ec) self->m_on_send_fail("Failed to ping"); + }); + } + void stop() override { - derived().ws().async_close( websocket::close_code::normal, [ this ]( beast::error_code ec ) { - if(ec) CSP_THROW(RuntimeException, ec.message()); - m_on_close(); + auto self = std::static_pointer_cast<Derived>(this->shared_from_this()); + derived().ws().async_close( websocket::close_code::normal, [ self ]( beast::error_code ec ) { + if(ec) self->m_on_fail(ec.message()); + self -> m_on_close(); }); } - void send( const std::string& s ) override + void send(const std::string& s) override { + auto self = std::static_pointer_cast<Derived>(this->shared_from_this()); net::post( derived().ws().get_executor(), - [this, s]() + [ self, s]() { - m_queue.push_back(s); - if (m_queue.size() > 1) return; - do_write(m_queue.front()); + self->m_queue.push_back(s); + if (self->m_queue.size() > 1) return; + self->do_write(self->m_queue.front()); } ); } void do_write(const std::string& s) override { + auto self = std::static_pointer_cast<Derived>(this->shared_from_this()); derived().ws().async_write( net::buffer(s), - [this](beast::error_code ec, std::size_t bytes_transfered) + [self](beast::error_code ec, std::size_t bytes_transfered) { - // add logging here? - m_queue.erase(m_queue.begin()); + self->m_queue.erase(self->m_queue.begin()); boost::ignore_unused(bytes_transfered); - if(ec) m_on_send_fail(ec.message()); - if(m_queue.size() >0) do_write(m_queue.front()); + if(ec) self->m_on_send_fail(ec.message()); + if(self->m_queue.size() > 0) + self->do_write(self->m_queue.front()); } ); - } +} public: tcp::resolver m_resolver; - Dictionary* m_properties; + std::shared_ptr<Dictionary> m_properties; void_cb m_on_open; string_cb m_on_fail; char_cb m_on_message; @@ -142,7 +166,7 @@ class WebsocketSessionNoTLS final: public WebsocketSession<WebsocketSessionNoTLS public: WebsocketSessionNoTLS( net::io_context& ioc, - Dictionary* properties, + std::shared_ptr<Dictionary> properties, void_cb& on_open, string_cb& on_fail, char_cb& on_message, @@ -161,57 +185,58 @@ class WebsocketSessionNoTLS final: public WebsocketSession<WebsocketSessionNoTLS { } void run() override { + auto self = std::static_pointer_cast<WebsocketSessionNoTLS>(this->shared_from_this()); m_resolver.async_resolve( m_properties->get<std::string>("host").c_str(), m_properties->get<std::string>("port").c_str(), - [this]( beast::error_code ec, tcp::resolver::results_type results ) { + [ self ]( beast::error_code ec, tcp::resolver::results_type results ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } // Set the timeout for the operation - beast::get_lowest_layer(m_ws).expires_after(std::chrono::seconds(5)); + beast::get_lowest_layer(self->m_ws).expires_after(std::chrono::seconds(5)); // Make the connection on the IP address we get from a lookup - beast::get_lowest_layer(m_ws).async_connect( + beast::get_lowest_layer(self->m_ws).async_connect( results, - [this]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) + [self]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) { // Turn off the timeout on the tcp_stream, because // the websocket stream has its own timeout system. if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - beast::get_lowest_layer(m_ws).expires_never(); + beast::get_lowest_layer(self->m_ws).expires_never(); - m_ws.set_option( + self->m_ws.set_option( websocket::stream_base::timeout::suggested( beast::role_type::client)); - m_ws.set_option(websocket::stream_base::decorator( - [this](websocket::request_type& req) + self->m_ws.set_option(websocket::stream_base::decorator( + [self](websocket::request_type& req) { - set_headers(req); + self -> set_headers(req); req.set(http::field::user_agent, "CSP WebsocketEndpoint"); } )); - std::string host_ = m_properties->get<std::string>("host") + ':' + std::to_string(ep.port()); - m_ws.async_handshake( + std::string host_ = self->m_properties->get<std::string>("host") + ':' + std::to_string(ep.port()); + self->m_ws.async_handshake( host_, - m_properties->get<std::string>("route"), - [this]( beast::error_code ec ) { + self->m_properties->get<std::string>("route"), + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_on_open(); - m_ws.async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + self->m_on_open(); + self->m_ws.async_read( + self->m_buffer, + [ self ]( beast::error_code ec, std::size_t bytes_transfered ) + { self->handle_message( ec, bytes_transfered ); } ); } ); @@ -232,7 +257,7 @@ class WebsocketSessionTLS final: public WebsocketSession<WebsocketSessionTLS> { WebsocketSessionTLS( net::io_context& ioc, ssl::context& ctx, - Dictionary* properties, + std::shared_ptr<Dictionary> properties, void_cb& on_open, string_cb& on_fail, char_cb& on_message, @@ -251,73 +276,74 @@ class WebsocketSessionTLS final: public WebsocketSession<WebsocketSessionTLS> { { } void run() override { + auto self = std::static_pointer_cast<WebsocketSessionTLS>(this->shared_from_this()); m_resolver.async_resolve( m_properties->get<std::string>("host").c_str(), m_properties->get<std::string>("port").c_str(), - [this]( beast::error_code ec, tcp::resolver::results_type results ) { + [self]( beast::error_code ec, tcp::resolver::results_type results ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } // Set the timeout for the operation - beast::get_lowest_layer(m_ws).expires_after(std::chrono::seconds(5)); + beast::get_lowest_layer(self->m_ws).expires_after(std::chrono::seconds(5)); // Make the connection on the IP address we get from a lookup - beast::get_lowest_layer(m_ws).async_connect( + beast::get_lowest_layer(self->m_ws).async_connect( results, - [this]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) + [self]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } if(! SSL_set_tlsext_host_name( - m_ws.next_layer().native_handle(), - m_properties->get<std::string>("host").c_str())) + self->m_ws.next_layer().native_handle(), + self->m_properties->get<std::string>("host").c_str())) { ec = beast::error_code(static_cast<int>(::ERR_get_error()), net::error::get_ssl_category()); - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_complete_host = m_properties->get<std::string>("host") + ':' + std::to_string(ep.port()); + self->m_complete_host = self->m_properties->get<std::string>("host") + ':' + std::to_string(ep.port()); // ssl handler - m_ws.next_layer().async_handshake( + self->m_ws.next_layer().async_handshake( ssl::stream_base::client, - [this]( beast::error_code ec ) { + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - beast::get_lowest_layer(m_ws).expires_never(); + beast::get_lowest_layer(self->m_ws).expires_never(); // Set suggested timeout settings for the websocket - m_ws.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + self->m_ws.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); // Set a decorator to change the User-Agent of the handshake - m_ws.set_option(websocket::stream_base::decorator( - [this](websocket::request_type& req) + self->m_ws.set_option(websocket::stream_base::decorator( + [self](websocket::request_type& req) { - set_headers(req); + self->set_headers(req); req.set(http::field::user_agent, "CSP WebsocketAdapter"); })); - m_ws.async_handshake( - m_complete_host, - m_properties->get<std::string>("route"), - [this]( beast::error_code ec ) { + self->m_ws.async_handshake( + self->m_complete_host, + self->m_properties->get<std::string>("route"), + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_on_open(); - m_ws.async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + self->m_on_open(); + self->m_ws.async_read( + self->m_buffer, + [ self ]( beast::error_code ec, std::size_t bytes_transfered ) + { self->handle_message( ec, bytes_transfered ); } ); } ); @@ -340,23 +366,26 @@ class WebsocketSessionTLS final: public WebsocketSession<WebsocketSessionTLS> { class WebsocketEndpoint { public: - WebsocketEndpoint( Dictionary properties ); - virtual ~WebsocketEndpoint() { }; + WebsocketEndpoint( net::io_context& ioc, Dictionary properties ); + ~WebsocketEndpoint(); void setOnOpen(void_cb on_open); void setOnFail(string_cb on_fail); void setOnMessage(char_cb on_message); void setOnClose(void_cb on_close); void setOnSendFail(string_cb on_send_fail); - Dictionary& getProperties(); + void updateHeaders(Dictionary properties); + std::shared_ptr<Dictionary> getProperties(); + // Dictionary& getProperties(); void run(); - void stop(); + void stop( bool stop_ioc = true); void send(const std::string& s); + void ping(); private: - Dictionary m_properties; - BaseWebsocketSession* m_session; - net::io_context m_ioc; + std::shared_ptr<Dictionary> m_properties; + std::shared_ptr<BaseWebsocketSession> m_session; + net::io_context& m_ioc; void_cb m_on_open; string_cb m_on_fail; char_cb m_on_message; diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp new file mode 100644 index 000000000..b960af825 --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -0,0 +1,432 @@ +#include <csp/adapters/websocket/WebsocketEndpointManager.h> + +namespace csp::adapters::websocket { + +WebsocketEndpointManager::WebsocketEndpointManager( ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine ) +: m_num_threads( static_cast<size_t>(properties.get<int64_t>("num_threads")) ), + m_ioc( m_num_threads ), + m_engine( engine ), + m_strand( boost::asio::make_strand(m_ioc) ), + m_mgr( mgr ), + m_updateAdapter( nullptr ), + m_properties( properties ), + m_work_guard(boost::asio::make_work_guard(m_ioc)), + m_dynamic( properties.get<bool>("dynamic") ){ + // Total number of subscribe and send function calls, set on the adapter manager + // when is it created. Note, that some of the input adapters might have been + // pruned from the graph and won't get created. + auto input_size = static_cast<size_t>(properties.get<int64_t>("subscribe_calls")); + m_inputAdapters.resize(input_size, nullptr); + m_consumer_endpoints.resize(input_size); + // send_calls + auto output_size = static_cast<size_t>(properties.get<int64_t>("send_calls")); + m_outputAdapters.resize(output_size, nullptr); + m_producer_endpoints.resize(output_size); + + // We choose to not automatically size m_connectionRequestAdapters + // since the index there is not meaningful, + // producers and subscribers are combined. + // We just hold onto their pointers. +}; + +WebsocketEndpointManager::~WebsocketEndpointManager() +{ +} + +void WebsocketEndpointManager::start(DateTime starttime, DateTime endtime) { + m_ioc.reset(); + if( !m_dynamic ){ + boost::asio::post(m_strand, [this]() { + // We subscribe for both the subscribe and send calls + // But we probably should check here. + if( m_outputAdapters.size() == 1) + handleConnectionRequest(Dictionary(m_properties), 0, false); + // If we have an input adapter call AND it's not pruned. + if( m_inputAdapters.size() == 1 && !adapterPruned(0)) + handleConnectionRequest(Dictionary(m_properties), 0, true); + }); + } + for (size_t i = 0; i < m_num_threads; ++i) { + m_threads.emplace_back(std::make_unique<std::thread>([this]() { + m_ioc.run(); + })); + } +}; + +bool WebsocketEndpointManager::adapterPruned( size_t caller_id ){ + return m_inputAdapters[caller_id] == nullptr; +}; + +void WebsocketEndpointManager::send(const std::string& value, const size_t& caller_id) { + const auto& endpoints = m_producer_endpoints[caller_id]; + // For each endpoint this producer is connected to + for (const auto& endpoint_id : endpoints) { + // Double check the endpoint exists and producer is still valid + if(publishesToEndpoint(caller_id, endpoint_id)) { + auto it = m_endpoints.find(endpoint_id); + if( it != m_endpoints.end()) + it->second.get()->send(value); + } + } +}; + +void WebsocketEndpointManager::removeEndpointForCallerId(const std::string& endpoint_id, bool is_consumer, size_t validated_id) +{ + if (is_consumer) { + WebsocketEndpointManager::removeConsumer(endpoint_id, validated_id); + } else { + WebsocketEndpointManager::removeProducer(endpoint_id, validated_id); + } + if (canRemoveEndpoint(endpoint_id)) + shutdownEndpoint(endpoint_id); +} + +void WebsocketEndpointManager::shutdownEndpoint(const std::string& endpoint_id) { + // This functions should only be called from the thread running m_ioc + // Cancel any pending reconnection attempts + if (auto config_it = m_endpoint_configs.find(endpoint_id); + config_it != m_endpoint_configs.end()) { + config_it->second.reconnect_timer->cancel(); + m_endpoint_configs.erase(config_it); + } + + // Stop and remove the endpoint + // No need to stop, destructo handles it + if (auto endpoint_it = m_endpoints.find(endpoint_id); endpoint_it != m_endpoints.end()) + m_endpoints.erase(endpoint_it); + std::stringstream ss; + ss << "No more connections for endpoint={" << endpoint_id << "} Shutting down..."; + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, ss.str()); +} + +void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, + std::unique_ptr<WebsocketEndpoint> endpoint, + std::string payload, + bool persist, + bool is_consumer, + size_t validated_id) +{ + // Store the endpoint first + auto& stored_endpoint = m_endpoints[endpoint_id] = std::move(endpoint); + + stored_endpoint->setOnOpen([this, endpoint_id, endpoint = stored_endpoint.get(), payload=std::move(payload), persist, is_consumer, validated_id]() { + auto [iter, inserted] = m_endpoint_configs.try_emplace(endpoint_id, m_ioc); + auto& config = iter->second; + config.connected = true; + config.attempting_reconnect = false; + + // Send consumer payloads + const auto& consumers = m_endpoint_consumers[endpoint_id]; + for (size_t i = 0; i < config.consumer_payloads.size(); ++i) { + if (!config.consumer_payloads[i].empty() && + i < consumers.size() && consumers[i]) { + endpoint->send(config.consumer_payloads[i]); + } + } + + // Send producer payloads + const auto& producers = m_endpoint_producers[endpoint_id]; + for (size_t i = 0; i < config.producer_payloads.size(); ++i) { + if (!config.producer_payloads[i].empty() && + i < producers.size() && producers[i]) { + endpoint->send(config.producer_payloads[i]); + } + } + // should only happen if persist is False + if ( !payload.empty() ) + endpoint -> send(payload); + + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE, + "Connected successfully for endpoint={" + endpoint_id +"}"); + // We remove the caller id, if it was the only one, then we shut down the endpoint + if( !persist ) + removeEndpointForCallerId(endpoint_id, is_consumer, validated_id); + }); + + stored_endpoint->setOnFail([this, endpoint_id](const std::string& reason) { + handleEndpointFailure(endpoint_id, reason, ClientStatusType::CONNECTION_FAILED); + }); + + stored_endpoint->setOnClose([this, endpoint_id]() { + // If we didn't close it ourselves + if (auto config_it = m_endpoint_configs.find(endpoint_id); config_it != m_endpoint_configs.end()) + handleEndpointFailure(endpoint_id, "Connection closed", ClientStatusType::CLOSED); + }); + stored_endpoint->setOnMessage([this, endpoint_id](void* data, size_t len) { + // Here we need to route to all active consumers for this endpoint + const auto& consumers = m_endpoint_consumers[endpoint_id]; + + // For each active consumer, we need to send to their input adapter + PushBatch batch( m_engine -> rootEngine() ); // TODO is this right? + for (size_t consumer_id = 0; consumer_id < consumers.size(); ++consumer_id) { + if (consumers[consumer_id]) { + std::vector<uint8_t> data_copy(static_cast<uint8_t*>(data), + static_cast<uint8_t*>(data) + len); + auto tup = std::tuple<std::string, void*>{endpoint_id, data_copy.data()}; + m_inputAdapters[consumer_id] -> processMessage( std::move(tup), len, &batch ); + } + } + }); + stored_endpoint -> setOnSendFail( + [ this, endpoint_id ]( const std::string& s ) { + std::stringstream ss; + ss << "Error: " << s << " for " << endpoint_id; + m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, ss.str() ); + } + ); + stored_endpoint -> run(); +}; + + +void WebsocketEndpointManager::handleEndpointFailure(const std::string& endpoint_id, + const std::string& reason, ClientStatusType status_type) { + // If there are any active consumers/producers, try to reconnect + if (!canRemoveEndpoint(endpoint_id)) { + auto [iter, inserted] = m_endpoint_configs.try_emplace(endpoint_id, m_ioc); + auto& config = iter->second; + config.connected = false; + + if (!config.attempting_reconnect) { + config.attempting_reconnect = true; + + // Schedule reconnection attempt + config.reconnect_timer->expires_after(config.reconnect_interval); + config.reconnect_timer->async_wait([this, endpoint_id](const error_code& ec) { + // boost::asio::post(m_ioc, [this, endpoint_id]() { + // If we still want to subscribe to this endpoint + if (auto it = m_endpoints.find(endpoint_id); + it != m_endpoints.end()) { + auto config_it = m_endpoint_configs.find(endpoint_id); + if (config_it != m_endpoint_configs.end()) { + auto& config = config_it -> second; + // We are no longer attempting to reconnect + config.attempting_reconnect = false; + } + it->second->run(); // Attempt to reconnect + } + }); + } + } else { + // No active consumers/producers, clean up the endpoint + m_endpoints.erase(endpoint_id); + m_endpoint_configs.erase(endpoint_id); + } + + std::stringstream ss; + ss << "Connection Failure for endpoint={" << endpoint_id << "} Due to: " << reason; + if ( status_type == ClientStatusType::CLOSED || status_type == ClientStatusType::ACTIVE ) + m_mgr -> pushStatus(StatusLevel::INFO, status_type, ss.str()); + else{ + m_mgr -> pushStatus(StatusLevel::ERROR, status_type, ss.str()); + } +}; + +void WebsocketEndpointManager::handleConnectionRequest(const Dictionary & properties, size_t validated_id, bool is_subscribe) +{ + // This should only get called from the thread running + // m_ioc. This allows us to avoid locks on internal data + // structures + // std::cout << " HEY YA\n"; + auto endpoint_id = properties.get<std::string>("uri"); + // std::cout << endpoint_id; + autogen::ActionType action = autogen::ActionType::create( properties.get<std::string>("action") ); + // std::cout << action.asString() << "\n"; + // // Change headers if needed here! + switch(action.enum_value()) { + case autogen::ActionType::enum_::CONNECT: { + auto persistent = properties.get<bool>("persistent"); + auto reconnect_interval = properties.get<TimeDelta>("reconnect_interval"); + // Update endpoint config + auto& config = m_endpoint_configs.try_emplace(endpoint_id, m_ioc).first->second; + + config.reconnect_interval = std::chrono::milliseconds( + reconnect_interval.asMilliseconds() + ); + std::string payload = ""; + bool has_payload = properties.tryGet<std::string>("on_connect_payload", payload); + + if (has_payload && !payload.empty() && persistent) { + auto& payloads = is_subscribe ? config.consumer_payloads : config.producer_payloads; + if (payloads.size() <= validated_id) { + payloads.resize(validated_id + 1); + } + payloads[validated_id] = std::move(payload); // Move to config + } + + if ( persistent ){ + if (is_subscribe) { + WebsocketEndpointManager::addConsumer(endpoint_id, validated_id); + } else { + WebsocketEndpointManager::addProducer(endpoint_id, validated_id); + } + } + + bool is_new_endpoint = !m_endpoints.contains(endpoint_id); + if (is_new_endpoint) { + auto endpoint = std::make_unique<WebsocketEndpoint>(m_ioc, properties); + // We can safely move payload regardless - if it was never written to, it's just an empty string + WebsocketEndpointManager::setupEndpoint(endpoint_id, std::move(endpoint), + (has_payload && !payload.empty() && persistent) ? "" : std::move(payload), + persistent, is_subscribe, validated_id ); + } + else{ + if( !persistent && !payload.empty() ) + m_endpoints[endpoint_id]->send(payload); + // Conscious decision to let non-persisten connection + // results to update the header + auto headers = *properties.get<DictionaryPtr>("headers"); + m_endpoints[endpoint_id]->updateHeaders(std::move(headers)); + } + // } + break; + } + + case csp::autogen::ActionType::enum_::DISCONNECT: { + // Clear persistence flag for this caller + removeEndpointForCallerId(endpoint_id, is_subscribe, validated_id); + break; + } + + case csp::autogen::ActionType::enum_::PING: { + // Only ping if the caller is actually connected to this endpoint + auto& consumers = m_endpoint_consumers[endpoint_id]; + auto& producers = m_endpoint_producers[endpoint_id]; + + if ( ( is_subscribe && validated_id < consumers.size() && consumers[validated_id] ) || + ( !is_subscribe && validated_id < producers.size() && producers[validated_id] ) ) { + if (auto it = m_endpoints.find(endpoint_id); it != m_endpoints.end()) { + it->second.get()->ping(); + } + } + break; + } + } +}; + +WebsocketEndpoint * WebsocketEndpointManager::getNonDynamicEndpoint(){ + // Should only be called if dynamic = False + if (!m_endpoints.empty()) { + return m_endpoints.begin()->second.get(); + } + return nullptr; +} + +void WebsocketEndpointManager::addConsumer(const std::string& endpoint_id, size_t caller_id) { + ensureVectorSize(m_endpoint_consumers[endpoint_id], caller_id); + m_endpoint_consumers[endpoint_id][caller_id] = true; + + m_consumer_endpoints[caller_id].insert(endpoint_id); +}; + +void WebsocketEndpointManager::addProducer(const std::string& endpoint_id, size_t caller_id) { + ensureVectorSize(m_endpoint_producers[endpoint_id], caller_id); + m_endpoint_producers[endpoint_id][caller_id] = true; + + m_producer_endpoints[caller_id].insert(endpoint_id); +}; + +bool WebsocketEndpointManager::canRemoveEndpoint(const std::string& endpoint_id) { + const auto& consumers = m_endpoint_consumers[endpoint_id]; + const auto& producers = m_endpoint_producers[endpoint_id]; + + // Check if any true values exist in either vector + return std::none_of(consumers.begin(), consumers.end(), [](bool b) { return b; }) && + std::none_of(producers.begin(), producers.end(), [](bool b) { return b; }); +}; + +void WebsocketEndpointManager::removeConsumer(const std::string& endpoint_id, size_t caller_id) { + auto& consumers = m_endpoint_consumers[endpoint_id]; + // Possibility it might not be subscribed, + // so we have this check. + if (caller_id < consumers.size()) { + consumers[caller_id] = false; + } + // We initialize these upfront, this will be valid. + m_consumer_endpoints[caller_id].erase(endpoint_id); +}; + +void WebsocketEndpointManager::removeProducer(const std::string& endpoint_id, size_t caller_id) { + auto& producers = m_endpoint_producers[endpoint_id]; + // Possibility it might not be publihsing to + // so we have this check. + if (caller_id < producers.size()) { + producers[caller_id] = false; + } + + // We initialize these upfront, this will be valid. + m_producer_endpoints[caller_id].erase(endpoint_id); +}; + + +void WebsocketEndpointManager::stop() { + // Stop all endpoints + // Endpoints running on m_ioc thread, + // So we call stop there + boost::asio::post(m_strand, [this]() { + for (auto& [endpoint_id, _] : m_endpoints) { + shutdownEndpoint(endpoint_id); + } + }); + // Stop the work guard to allow the io_context to complete + m_work_guard.reset(); + m_ioc.stop(); + + // Wait for all threads to finish + for (auto& thread : m_threads) { + if (thread && thread->joinable()) { + thread->join(); + } + } + + // Clear threads before other members are destroyed + m_threads.clear(); +}; + +PushInputAdapter* WebsocketEndpointManager::getInputAdapter(CspTypePtr & type, PushMode pushMode, const Dictionary & properties) +{ + auto caller_id = properties.get<int64_t>("caller_id"); + size_t validated_id = validateCallerId(caller_id); + auto input_adapter = m_engine -> createOwnedObject<ClientInputAdapter>( + type, + pushMode, + properties, + m_dynamic + ); + m_inputAdapters[validated_id] = input_adapter; + return m_inputAdapters[validated_id]; +}; + +OutputAdapter* WebsocketEndpointManager::getOutputAdapter( const Dictionary & properties ) +{ + auto caller_id = properties.get<int64_t>("caller_id"); + size_t validated_id = validateCallerId(caller_id); + assert(!properties.get<bool>("is_subscribe")); + assert(m_outputAdapters.size() == validated_id); + + auto output_adapter = m_engine -> createOwnedObject<ClientOutputAdapter>( this, validated_id, m_ioc, m_strand ); + m_outputAdapters[validated_id] = output_adapter; + return m_outputAdapters[validated_id]; +}; + +OutputAdapter * WebsocketEndpointManager::getHeaderUpdateAdapter() +{ + if (m_updateAdapter == nullptr) + m_updateAdapter = m_engine -> createOwnedObject<ClientHeaderUpdateOutputAdapter>( this, m_strand ); + + return m_updateAdapter; +}; + +OutputAdapter * WebsocketEndpointManager::getConnectionRequestAdapter( const Dictionary & properties ) +{ + auto caller_id = properties.get<int64_t>("caller_id"); + auto is_subscribe = properties.get<bool>("is_subscribe"); + + auto* adapter = m_engine->createOwnedObject<ClientConnectionRequestAdapter>( + this, is_subscribe, caller_id, m_strand + ); + m_connectionRequestAdapters.push_back(adapter); + + return adapter; +}; + +} diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.h b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h new file mode 100644 index 000000000..736413e03 --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h @@ -0,0 +1,152 @@ +#ifndef WEBSOCKET_ENDPOINT_MANAGER_H +#define WEBSOCKET_ENDPOINT_MANAGER_H + +#include <boost/asio.hpp> +#include <csp/adapters/websocket/WebsocketClientTypes.h> +#include <csp/adapters/websocket/WebsocketEndpoint.h> +#include <csp/adapters/websocket/ClientAdapterManager.h> +#include <csp/adapters/websocket/ClientInputAdapter.h> +#include <csp/adapters/websocket/ClientOutputAdapter.h> +#include <csp/adapters/websocket/ClientHeaderUpdateAdapter.h> +#include <csp/adapters/websocket/ClientConnectionRequestAdapter.h> +#include <csp/core/Enum.h> +#include <csp/core/Hash.h> +#include <csp/engine/AdapterManager.h> +#include <csp/engine/Dictionary.h> +#include <csp/engine/PushInputAdapter.h> +#include <csp/core/Platform.h> +#include <thread> +#include <chrono> +#include <atomic> +#include <mutex> +#include <condition_variable> +#include <iomanip> +#include <iostream> +#include <vector> +#include <unordered_set> +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> +#include <chrono> +#include <optional> +#include <functional> + +namespace csp::adapters::websocket { +using namespace csp; +class WebsocketEndpoint; + +class ClientAdapterManager; +class ClientOutputAdapter; +class ClientConnectionRequestAdapter; +class ClientHeaderUpdateOutputAdapter; + +struct ConnectPayloads { + std::vector<std::string> consumer_payloads; + std::vector<std::string> producer_payloads; +}; + +struct EndpointConfig { + std::chrono::milliseconds reconnect_interval; + std::unique_ptr<boost::asio::steady_timer> reconnect_timer; + bool attempting_reconnect{false}; + bool connected{false}; + + // Payloads for different client types + std::vector<std::string> consumer_payloads; + std::vector<std::string> producer_payloads; + + explicit EndpointConfig(boost::asio::io_context& ioc) + : reconnect_timer(std::make_unique<boost::asio::steady_timer>(ioc)) {} +}; + +// Callbacks for endpoint events +struct EndpointCallbacks { + std::function<void(const std::string&)> onOpen; + std::function<void(const std::string&, const std::string&)> onFail; + std::function<void(const std::string&)> onClose; + std::function<void(const std::string&, const std::string&)> onSendFail; + std::function<void(const std::string&, void*, size_t)> onMessage; +}; + +class WebsocketEndpointManager { +public: + explicit WebsocketEndpointManager(ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine); + ~WebsocketEndpointManager(); + void send(const std::string& value, const size_t& caller_id); + // Whether the input adapter (subscribe) given by a specific caller_id was pruned + bool adapterPruned( size_t caller_id ); + // Whether the output adapater (publish) given by a specific caller_id publishes to a given endpoint + + void start(DateTime starttime, DateTime endtime); + void stop(); + + void handleConnectionRequest(const Dictionary & properties, size_t validated_id, bool is_subscribe); + void handleEndpointFailure(const std::string& endpoint_id, const std::string& reason, ClientStatusType status_type); + + void setupEndpoint(const std::string& endpoint_id, std::unique_ptr<WebsocketEndpoint> endpoint, std::string payload, bool persist, bool is_consumer, size_t validated_id); + void shutdownEndpoint(const std::string& endpoint_id); + + void addConsumer(const std::string& endpoint_id, size_t caller_id); + void addProducer(const std::string& endpoint_id, size_t caller_id); + bool canRemoveEndpoint(const std::string& endpoint_id); + + void removeEndpointForCallerId(const std::string& endpoint_id, bool is_consumer, size_t validated_id); + void removeConsumer(const std::string& endpoint_id, size_t caller_id); + void removeProducer(const std::string& endpoint_id, size_t caller_id); + + WebsocketEndpoint * getNonDynamicEndpoint(); + PushInputAdapter * getInputAdapter( CspTypePtr & type, PushMode pushMode, const Dictionary & properties ); + OutputAdapter * getOutputAdapter( const Dictionary & properties ); + OutputAdapter * getHeaderUpdateAdapter(); + OutputAdapter * getConnectionRequestAdapter( const Dictionary & properties ); +private: + inline size_t validateCallerId(int64_t caller_id) const { + if (caller_id < 0) { + CSP_THROW(ValueError, "caller_id cannot be negative: " << caller_id); + } + return static_cast<size_t>(caller_id); + } + inline void ensureVectorSize(std::vector<bool>& vec, size_t caller_id) { + if (vec.size() <= caller_id) { + vec.resize(caller_id + 1, false); + } + } + // Whether the output adapater (publish) given by a specific caller_id publishes to a given endpoint + inline bool publishesToEndpoint(const size_t caller_id, const std::string& endpoint_id){ + auto config_it = m_endpoint_configs.find(endpoint_id); + if( config_it == m_endpoint_configs.end() || !config_it->second.connected ) + return false; + + return caller_id < m_endpoint_producers[endpoint_id].size() && + m_endpoint_producers[endpoint_id][caller_id]; + } + size_t m_num_threads; + net::io_context m_ioc; + Engine* m_engine; + boost::asio::strand<boost::asio::io_context::executor_type> m_strand; + ClientAdapterManager* m_mgr; + ClientHeaderUpdateOutputAdapter* m_updateAdapter; + std::vector<std::unique_ptr<std::thread>> m_threads; + Dictionary m_properties; + std::vector<ClientConnectionRequestAdapter*> m_connectionRequestAdapters; + + // Bidirectional mapping using vectors since caller_ids are sequential + // Maybe not efficient? Should be good for small number of edges though + std::unordered_map<std::string, std::vector<bool>> m_endpoint_consumers; // endpoint_id -> vector[caller_id] for consuemrs + std::unordered_map<std::string, std::vector<bool>> m_endpoint_producers; // endpoint_id -> vector[caller_id] for producers + + // Quick lookup for caller's endpoints + std::vector< std::unordered_set<std::string> > m_consumer_endpoints; // caller_id -> set of endpoints they consume from + std::vector< std::unordered_set<std::string> > m_producer_endpoints; // caller_id -> set of endpoints they produce to + boost::asio::executor_work_guard<boost::asio::io_context::executor_type> m_work_guard; + std::unordered_map<std::string, std::unique_ptr<WebsocketEndpoint>> m_endpoints; + std::unordered_map<std::string, EndpointConfig> m_endpoint_configs; + std::vector<ClientInputAdapter*> m_inputAdapters; + std::vector<ClientOutputAdapter*> m_outputAdapters; + bool m_dynamic; +}; + +} +#endif \ No newline at end of file diff --git a/cpp/csp/python/Conversions.h b/cpp/csp/python/Conversions.h index 2422aaea1..a3ada6b2d 100644 --- a/cpp/csp/python/Conversions.h +++ b/cpp/csp/python/Conversions.h @@ -666,6 +666,30 @@ inline Dictionary fromPython( PyObject * o ) return out; } +template<> +inline std::vector<Dictionary> fromPython(PyObject* o) +{ + if (!PyList_Check(o)) + CSP_THROW(TypeError, "List of dictionaries conversion expected type list got " << Py_TYPE(o)->tp_name); + + Py_ssize_t size = PyList_GET_SIZE(o); + std::vector<Dictionary> out; + out.reserve(size); + + for (Py_ssize_t i = 0; i < size; ++i) + { + PyObject* item = PyList_GET_ITEM(o, i); + + // Skip None values like in Dictionary conversion + if (item == Py_None) + continue; + + out.emplace_back(fromPython<Dictionary>(item)); + } + + return out; +} + template<> inline std::vector<Dictionary::Data> fromPython( PyObject * o ) { diff --git a/cpp/csp/python/adapters/websocketadapterimpl.cpp b/cpp/csp/python/adapters/websocketadapterimpl.cpp index d636932da..db6591271 100644 --- a/cpp/csp/python/adapters/websocketadapterimpl.cpp +++ b/cpp/csp/python/adapters/websocketadapterimpl.cpp @@ -8,6 +8,7 @@ #include <csp/python/PyInputAdapterWrapper.h> #include <csp/python/PyOutputAdapterWrapper.h> +#include <iostream> using namespace csp::adapters::websocket; namespace csp::python @@ -45,7 +46,11 @@ static OutputAdapter * create_websocket_output_adapter( csp::AdapterManager * ma auto * websocketManager = dynamic_cast<ClientAdapterManager*>( manager ); if( !websocketManager ) CSP_THROW( TypeError, "Expected WebsocketClientAdapterManager" ); - return websocketManager -> getOutputAdapter(); + PyObject * pyProperties; + if( !PyArg_ParseTuple( args, "O!", + &PyDict_Type, &pyProperties ) ) + CSP_THROW( PythonPassthrough, "" ); + return websocketManager -> getOutputAdapter(fromPython<Dictionary>( pyProperties )); } static OutputAdapter * create_websocket_header_update_adapter( csp::AdapterManager * manager, PyEngine * pyengine, PyObject * args ) @@ -56,10 +61,40 @@ static OutputAdapter * create_websocket_header_update_adapter( csp::AdapterManag return websocketManager -> getHeaderUpdateAdapter(); } +static OutputAdapter * create_websocket_connection_request_adapter( csp::AdapterManager * manager, PyEngine * pyengine, PyObject * args ) +{ + // std::cout << "hereeeee33ee" << "\n"; + PyObject * pyProperties; + // PyObject * type; + auto * websocketManager = dynamic_cast<ClientAdapterManager*>( manager ); + if( !websocketManager ) + CSP_THROW( TypeError, "Expected WebsocketClientAdapterManager" ); + + if( !PyArg_ParseTuple( args, "O!", + &PyDict_Type, &pyProperties ) ) + CSP_THROW( PythonPassthrough, "" ); + // std::cout << "hereeeee334444ee" << "\n"; + return websocketManager -> getConnectionRequestAdapter( fromPython<Dictionary>( pyProperties ) ); + + + // TODO + // Here I think we should have a websocket connection manager + // that will handle the connections and endpoint management + // It will create the connection request output adapter + // That output adapter, when it ticks, with a list of python Dictionary + // will then use the boost beast 'post' function to schedule, on the + // io context, a callback to process that dict (on the websocket connection manager!!!) and handle the endpoint manipulation appropriately + + // that websocket connection manager will run the thread with the io context + // being run. Move it away from clientAdapterManager +} + REGISTER_ADAPTER_MANAGER( _websocket_adapter_manager, create_websocket_adapter_manager ); REGISTER_INPUT_ADAPTER( _websocket_input_adapter, create_websocket_input_adapter ); REGISTER_OUTPUT_ADAPTER( _websocket_output_adapter, create_websocket_output_adapter ); REGISTER_OUTPUT_ADAPTER( _websocket_header_update_adapter, create_websocket_header_update_adapter); +REGISTER_OUTPUT_ADAPTER( _websocket_connection_request_adapter, create_websocket_connection_request_adapter); + static PyModuleDef _websocketadapterimpl_module = { PyModuleDef_HEAD_INIT, diff --git a/csp/adapters/dynamic_adapter_utils.py b/csp/adapters/dynamic_adapter_utils.py new file mode 100644 index 000000000..1e8d1ae3a --- /dev/null +++ b/csp/adapters/dynamic_adapter_utils.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, NonNegativeInt + + +class AdapterInfo(BaseModel): + caller_id: NonNegativeInt + is_subscribe: bool diff --git a/csp/adapters/websocket.py b/csp/adapters/websocket.py index 2ba305c9a..5fdeb049e 100644 --- a/csp/adapters/websocket.py +++ b/csp/adapters/websocket.py @@ -21,8 +21,10 @@ from csp.impl.wiring.delayed_node import DelayedNodeWrapperDef from csp.lib import _websocketadapterimpl -from .websocket_types import WebsocketHeaderUpdate +from .dynamic_adapter_utils import AdapterInfo +from .websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus # noqa +# InternalConnectionRequest, _ = ( BytesMessageProtoMapper, DateTimeType, @@ -59,6 +61,12 @@ def diff_dict(old, new): return d +def _sanitize_port(uri: str, port): + if port: + return str(port) + return "443" if uri.startswith("wss") else "80" + + class TableManager: def __init__(self, tables, delta_updates): self._tables = tables @@ -237,7 +245,7 @@ def on_close(self): self._manager.unsubscribe(self) def on_message(self, message): - logging.info("got message %r", message) + logging.warning("got message %r", message) # TODO Ignore for now # parsed = rapidjson.loads(message) @@ -387,12 +395,34 @@ def _instantiate(self): _launch_application(self._port, manager, csp.const("stub")) +# Maybe, we can have the Adapter manager have all the connections +# If we get a new connection request, we include that adapter for the +# subscriptions. When we pop it, we remove it. +# Then, each edge will effectively be independent. +# Maybe. have each websocket push to a shared queue, then from there we +# pass it along to all edges ("input adapters") that are subscribed to it + +# Ok, maybe, let's keep it at just 1 subscribe and send call. +# However, we can subscribe to the send and subscribe calls separately. +# We just have to keep track of the Endpoints we have, and + + class WebsocketAdapterManager: + """ + Can subscribe dynamically via ts[List[ConnectionRequest]] + + We use a ts[List[ConnectionRequest]] to allow users to submit a batch of conneciton requests in + a single engine cycle. + """ + def __init__( self, - uri: str, + uri: Optional[str] = None, reconnect_interval: timedelta = timedelta(seconds=2), - headers: Dict[str, str] = None, + headers: Optional[Dict[str, str]] = None, + dynamic: bool = False, + connection_request: Optional[ConnectionRequest] = None, + num_threads: int = 1, ): """ uri: str @@ -401,26 +431,83 @@ def __init__( time interval to wait before trying to reconnect (must be >= 1 second) headers: Dict[str, str] = None headers to apply to the request during the handshake + dynamic: bool = False + Whether we accept dynamically altering the connections via ConnectionRequest objects. + num_threads: int = 1 + Determines number of threads to allocate for running the websocket endpoints. + Defaults to 1 to avoid thread switching """ + + self._properties = dict(dynamic=dynamic, num_threads=num_threads) + # Enumerating for clarity + if connection_request is not None and uri is not None: + raise ValueError("'connection_request' cannot be set along with 'uri'") + + # Exactly 1 of connection_request and uri is None + if connection_request is not None or uri is not None: + if connection_request is None: + connection_request = ConnectionRequest( + uri=uri, reconnect_interval=reconnect_interval, headers=headers or {} + ) + self._properties.update(self._get_properties(connection_request)) + + # This is a counter that will be used to identify every function call + # We keep track of the subscribes and sends separately + self._subscribe_call_id = 0 + self._send_call_id = 0 + + # This maps types to their wrapper structs + self._wrapper_struct_dict = {} + + @property + def _dynamic(self): + return self._properties.get("dynamic", False) + + def _get_properties(self, conn_request: ConnectionRequest) -> dict: + uri = conn_request.uri + reconnect_interval = conn_request.reconnect_interval + assert reconnect_interval >= timedelta(seconds=1) resp = urllib.parse.urlparse(uri) if resp.hostname is None: raise ValueError(f"Failed to parse host from URI: {uri}") - self._properties = dict( + res = dict( host=resp.hostname, # if no port is explicitly present in the uri, the resp.port is None - port=self._sanitize_port(uri, resp.port), + port=_sanitize_port(uri, resp.port), route=resp.path or "/", # resource shouldn't be empty string use_ssl=uri.startswith("wss"), reconnect_interval=reconnect_interval, - headers=headers if headers else {}, + headers=conn_request.headers, + persistent=conn_request.persistent, + action=conn_request.action.name, + on_connect_payload=conn_request.on_connect_payload, + uri=uri, + dynamic=self._dynamic, ) + return res - def _sanitize_port(self, uri: str, port): - if port: - return str(port) - return "443" if uri.startswith("wss") else "80" + def _get_caller_id(self, is_subscribe: bool) -> int: + if is_subscribe: + caller_id = self._subscribe_call_id + self._subscribe_call_id += 1 + else: + caller_id = self._send_call_id + self._send_call_id += 1 + return caller_id + + def get_wrapper_struct(self, ts_type: type): + if (dynamic_type := self._wrapper_struct_dict.get(ts_type)) is None: + # I want to preserve type information + # Not sure a better way to do this + class CustomWrapperStruct(csp.Struct): + msg: ts_type # noqa + uri: str + + dynamic_type = CustomWrapperStruct + self._wrapper_struct_dict[ts_type] = dynamic_type + return dynamic_type def subscribe( self, @@ -429,7 +516,27 @@ def subscribe( field_map: Union[dict, str] = None, meta_field_map: dict = None, push_mode: csp.PushMode = csp.PushMode.NON_COLLAPSING, + connection_request: Optional[ts[List[ConnectionRequest]]] = None, ): + """If dynamic is True, this will tick a custom WrapperStruct, + with 'msg' as the correct type of the message. + And 'uri' that specifies the 'uri' the message comes from. + + Otherwise, returns just message. + + ts_type should be original type!! The tuple wrapping happens + automatically + """ + caller_id = self._get_caller_id(is_subscribe=True) + # Gives validation, more to start defining a common interface + adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=True).model_dump() + connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request + request_dict = csp.apply( + connection_request, lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], list + ) + # Output adapter to handle connection requests + _websocket_connection_request_adapter_def(self, request_dict, adapter_props) + field_map = field_map or {} meta_field_map = meta_field_map or {} if isinstance(field_map, str): @@ -442,12 +549,26 @@ def subscribe( properties["field_map"] = field_map properties["meta_field_map"] = meta_field_map + properties.update(adapter_props) + # We wrap the message in a struct to note the url it comes from + if self._dynamic: + ts_type = self.get_wrapper_struct(ts_type=ts_type) return _websocket_input_adapter_def(self, ts_type, properties, push_mode=push_mode) - def send(self, x: ts["T"]): - return _websocket_output_adapter_def(self, x) + def send(self, x: ts["T"], connection_request: Optional[ts[List[ConnectionRequest]]] = None): + caller_id = self._get_caller_id(is_subscribe=False) + # Gives validation, more to start defining a common interface + adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=False).model_dump() + connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request + request_dict = csp.apply( + connection_request, lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], list + ) + _websocket_connection_request_adapter_def(self, request_dict, adapter_props) + return _websocket_output_adapter_def(self, x, adapter_props) def update_headers(self, x: ts[List[WebsocketHeaderUpdate]]): + if self._dynamic: + raise ValueError("If dynamic, cannot call update_headers") return _websocket_header_update_adapter_def(self, x) def status(self, push_mode=csp.PushMode.NON_COLLAPSING): @@ -456,6 +577,7 @@ def status(self, push_mode=csp.PushMode.NON_COLLAPSING): def _create(self, engine, memo): """method needs to return the wrapped c++ adapter manager""" + self._properties.update({"subscribe_calls": self._subscribe_call_id, "send_calls": self._send_call_id}) return _websocketadapterimpl._websocket_adapter_manager(engine, self._properties) @@ -473,6 +595,7 @@ def _create(self, engine, memo): _websocketadapterimpl._websocket_output_adapter, WebsocketAdapterManager, input=ts["T"], + properties=dict, ) _websocket_header_update_adapter_def = output_adapter_def( @@ -481,3 +604,11 @@ def _create(self, engine, memo): WebsocketAdapterManager, input=ts[List[WebsocketHeaderUpdate]], ) + +_websocket_connection_request_adapter_def = output_adapter_def( + "websocket_connection_request_adapter", + _websocketadapterimpl._websocket_connection_request_adapter, + WebsocketAdapterManager, + input=ts[list], # needed, List[dict] didn't work on c++ level + properties=dict, +) diff --git a/csp/adapters/websocket_types.py b/csp/adapters/websocket_types.py index 710610501..314d2d8a9 100644 --- a/csp/adapters/websocket_types.py +++ b/csp/adapters/websocket_types.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from typing import Dict + from csp.impl.enum import Enum from csp.impl.struct import Struct @@ -12,6 +15,22 @@ class WebsocketStatus(Enum): MESSAGE_SEND_FAIL = 4 +class ActionType(Enum): + CONNECT = 0 + DISCONNECT = 1 + PING = 2 + + class WebsocketHeaderUpdate(Struct): key: str value: str + + +class ConnectionRequest(Struct): + uri: str + action: ActionType = ActionType.CONNECT # Connect, Disconnect, Ping, etc + # Whetehr we maintain the connection + persistent: bool = True # Only relevant for Connect requests + reconnect_interval: timedelta = timedelta(seconds=2) + on_connect_payload: str = "" # message to send on connect + headers: Dict[str, str] = {} diff --git a/csp/tests/adapters/test_websocket.py b/csp/tests/adapters/test_websocket.py index 41aeb3311..fa1463d70 100644 --- a/csp/tests/adapters/test_websocket.py +++ b/csp/tests/adapters/test_websocket.py @@ -1,9 +1,10 @@ import os +import pytest import pytz import threading -import unittest +from contextlib import contextmanager from datetime import datetime, timedelta -from typing import List +from typing import List, Optional, Type import csp from csp import ts @@ -13,28 +14,70 @@ import tornado.web import tornado.websocket - from csp.adapters.websocket import JSONTextMessageMapper, RawTextMessageMapper, Status, WebsocketAdapterManager + from csp.adapters.websocket import ( + ActionType, + ConnectionRequest, + JSONTextMessageMapper, + RawTextMessageMapper, + Status, + WebsocketAdapterManager, + WebsocketHeaderUpdate, + WebsocketStatus, + ) class EchoWebsocketHandler(tornado.websocket.WebSocketHandler): def on_message(self, msg): + # Carve-out to allow inspecting the headers + if msg == "header1": + msg = self.request.headers.get(msg, "") return self.write_message(msg) - -@unittest.skipIf(not os.environ.get("CSP_TEST_WEBSOCKET"), "Skipping websocket adapter tests") -class TestWebsocket(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) - cls.app.listen(8000) - cls.io_loop = tornado.ioloop.IOLoop.current() - cls.io_thread = threading.Thread(target=cls.io_loop.start) - cls.io_thread.start() - - @classmethod - def tearDownClass(cls): - cls.io_loop.add_callback(cls.io_loop.stop) - if cls.io_thread: - cls.io_thread.join() + @contextmanager + def create_tornado_server(port: int): + """Base context manager for creating a Tornado server in a thread""" + ready_event = threading.Event() + io_loop = None + app = None + io_thread = None + + def run_io_loop(): + nonlocal io_loop, app + io_loop = tornado.ioloop.IOLoop() + io_loop.make_current() + app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) + app.listen(port) + ready_event.set() + io_loop.start() + + io_thread = threading.Thread(target=run_io_loop) + io_thread.start() + ready_event.wait() + + try: + yield io_loop, app, io_thread + finally: + io_loop.add_callback(io_loop.stop) + if io_thread: + io_thread.join(timeout=5) + if io_thread.is_alive(): + raise RuntimeError("IOLoop failed to stop") + + @contextmanager + def tornado_server(port: int = 8001): + """Simplified context manager that uses the base implementation""" + with create_tornado_server(port) as (_io_loop, _app, _io_thread): + yield + + +@pytest.mark.skipif(os.environ.get("CSP_TEST_WEBSOCKET") is None, reason="'CSP_TEST_WEBSOCKET' env variable is not set") +class TestWebsocket: + @pytest.fixture(scope="class", autouse=True) + def setup_tornado(self, request): + with create_tornado_server(8000) as (io_loop, app, io_thread): + request.cls.io_loop = io_loop + request.cls.app = app + request.cls.io_thread = io_thread + yield def test_send_recv_msg(self): @csp.node @@ -55,6 +98,126 @@ def g(): msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) assert msgs["recv"][0][1] == "Hello, World!" + def test_headers(self): + @csp.graph + def g(dynamic: bool): + if dynamic: + ws = WebsocketAdapterManager(dynamic=True) + # Connect with header + conn_request1 = csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value1"} + ) + ] + ) + # Disconnect to shutdown endpoint + conn_request2 = csp.const( + [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT)], + delay=timedelta(milliseconds=100), + ) + # Reconnect to open endpoint with new headers + conn_request3 = csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value2"} + ) + ], + delay=timedelta(milliseconds=150), + ) + conn_request3 = csp.const( + [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.PING)], + delay=timedelta(milliseconds=151), + ) + conn_request4 = csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value2"} + ) + ], + delay=timedelta(milliseconds=200), + ) + conn_req = csp.flatten([conn_request1, conn_request2, conn_request3, conn_request4]) + status = ws.status() + csp.add_graph_output("status", status) + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_req) + csp.add_graph_output("recv", recv) + stop = csp.filter(csp.count(recv) == 2, recv) + csp.stop_engine(stop) + + if not dynamic: + ws = WebsocketAdapterManager("ws://localhost:8000/", headers={"header1": "value1"}) + status = ws.status() + send_msg = csp.sample(status, csp.const("header1")) + to_send = csp.merge(send_msg, csp.const("header1", delay=timedelta(milliseconds=100))) + ws.send(to_send) + recv = ws.subscribe(str, RawTextMessageMapper()) + + header_update = csp.const( + [WebsocketHeaderUpdate(key="header1", value="value2")], delay=timedelta(milliseconds=50) + ) + # Doesn' tick out since we don't disconnect + ws.update_headers(header_update) + status = ws.status() + csp.add_graph_output("status", status) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + + msgs = csp.run(g, dynamic=False, starttime=datetime.now(pytz.UTC), realtime=True) + assert msgs["recv"][0][1] == "value1" + assert len(msgs["status"]) == 1 + assert msgs["status"][0][1].status_code == WebsocketStatus.ACTIVE.value + + msgs = csp.run(g, dynamic=True, starttime=datetime.now(pytz.UTC), realtime=True) + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][1][1].uri == "ws://localhost:8000/" + assert msgs["recv"][0][1].msg == "value1" + assert msgs["recv"][1][1].msg == "value2" + + assert len(msgs["status"]) == 3 + assert msgs["status"][0][1].status_code == WebsocketStatus.ACTIVE.value + assert msgs["status"][1][1].status_code == WebsocketStatus.CLOSED.value + assert msgs["status"][2][1].status_code == WebsocketStatus.ACTIVE.value + + @pytest.mark.parametrize("send_payload_subscribe", [True, False]) + def test_send_recv_json_dynamic_on_connect_payload(self, send_payload_subscribe): + class MsgStruct(csp.Struct): + a: int + b: str + + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + conn_request = ConnectionRequest( + uri="ws://localhost:8000/", + action=ActionType.CONNECT, + on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), + ) + if not send_payload_subscribe: + # We send payload via the dummy send function + # The 'on_connect_payload sends the result + ws.send(csp.null_ts(object), connection_request=csp.const([conn_request])) + subscribe_connection_request = ( + [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.CONNECT)] + if not send_payload_subscribe + else [conn_request] + ) + recv = ws.subscribe( + MsgStruct, JSONTextMessageMapper(), connection_request=csp.const(subscribe_connection_request) + ) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + obj = msgs["recv"][0][1] + assert obj.uri == "ws://localhost:8000/" + true_obj = obj.msg + assert isinstance(true_obj, MsgStruct) + assert true_obj.a == 1234 + assert true_obj.b == "im a string" + def test_send_recv_json(self): class MsgStruct(csp.Struct): a: int @@ -119,37 +282,256 @@ def g(n: int): assert len(msgs["recv"]) == n assert msgs["recv"][0][1] != msgs["recv"][-1][1] - def test_unkown_host_graceful_shutdown(self): + def test_send_multiple_and_recv_msgs_dynamic(self): @csp.graph def g(): - ws = WebsocketAdapterManager("wss://localhost/") - assert ws._properties["port"] == "443" - csp.stop_engine(ws.status()) + ws = WebsocketAdapterManager(dynamic=True) + conn_request = csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", + action=ActionType.CONNECT, + ) + ] + ) + val = csp.curve(int, [(timedelta(milliseconds=50), 0), (timedelta(milliseconds=500), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=100)) + + # We connect immediately and send out the hello message + ws.send(hello, connection_request=conn_request) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=delayed_conn_req) + # This call connects first + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + + merged = csp.flatten([recv, recv2]) + csp.add_graph_output("recv", merged.msg) + + stop = csp.filter(csp.count(merged) == 3, merged) + csp.stop_engine(stop) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True) + assert len(msgs["recv"]) == 3 + # the first message sent out, only the second subscribe call picks this up + assert msgs["recv"][0][1] == "hi world0" + # Both the subscribe calls receive this message + assert msgs["recv"][1][1] == "hi world1" + assert msgs["recv"][2][1] == "hi world1" + + @pytest.mark.parametrize("reconnect_immeditately", [False, True]) + def test_dynamic_disconnect_connect_pruned_subscribe(self, reconnect_immeditately): + @csp.node + def prevent_prune(objs: ts[str]): + if csp.ticked(objs): + # Does nothing but makes sure it's not pruned + ... - csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + + if reconnect_immeditately: + disconnect_reqs = [ + ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT), + ConnectionRequest(uri="ws://localhost:8000/"), + ] + else: + disconnect_reqs = [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT)] + conn_request = csp.curve( + List[ConnectionRequest], + [ + (timedelta(), [ConnectionRequest(uri="ws://localhost:8000/")]), + ( + timedelta(milliseconds=100), + disconnect_reqs, + ), + ( + timedelta(milliseconds=350), + [ + ConnectionRequest( + uri="ws://localhost:8000/", + headers={"dummy_key": "dummy_value"}, + ), + ], + ), + ], + ) + const_conn_request = csp.const([ConnectionRequest(uri="ws://localhost:8000/")]) + val = csp.curve(int, [(timedelta(milliseconds=100, microseconds=1), 0), (timedelta(milliseconds=500), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + + # We connect immediately and send out the hello message + ws.send(hello, connection_request=const_conn_request) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + # This gets pruned by csp + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + recv3 = ws.subscribe(str, RawTextMessageMapper(), connection_request=const_conn_request) + + no_persist_conn = ConnectionRequest( + uri="ws://localhost:8000/", persistent=False, on_connect_payload="hi non-persistent world!" + ) + recv4 = ws.subscribe( + str, + RawTextMessageMapper(), + connection_request=csp.const([no_persist_conn], delay=timedelta(milliseconds=250)), + ) - def test_send_recv_burst_json(self): - class MsgStruct(csp.Struct): - a: int - b: str + csp.add_graph_output("recv", recv) + csp.add_graph_output("recv3", recv3) + csp.add_graph_output("recv4", recv4) + end = csp.filter(csp.count(recv3) == 3, recv3) + csp.stop_engine(end) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True) + # Did not persist, so did not receive any messages + assert len(msgs["recv4"]) == 0 + # Only the second message is received, since we disonnect before the first one is sent + if not reconnect_immeditately: + assert len(msgs["recv"]) == 1 + assert msgs["recv"][0][1].msg == "hi world1" + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + else: + assert len(msgs["recv"]) == 3 + assert msgs["recv"][0][1].msg == "hi world0" + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][1][1].msg == "hi non-persistent world!" + assert msgs["recv"][1][1].uri == "ws://localhost:8000/" + assert msgs["recv"][2][1].msg == "hi world1" + assert msgs["recv"][2][1].uri == "ws://localhost:8000/" + + # This subscribe call received all the messages + assert len(msgs["recv3"]) == 3 + assert msgs["recv3"][0][1].msg == "hi world0" + assert msgs["recv3"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv3"][1][1].msg == "hi non-persistent world!" + assert msgs["recv3"][1][1].uri == "ws://localhost:8000/" + assert msgs["recv3"][2][1].msg == "hi world1" + assert msgs["recv3"][2][1].uri == "ws://localhost:8000/" + + def test_dynamic_pruned_subscribe(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + conn_request = csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", + action=ActionType.CONNECT, + ) + ] + ) + val = csp.curve(int, [(timedelta(milliseconds=50), 0), (timedelta(milliseconds=500), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=100)) + + # We connect immediately and send out the hello message + ws.send(hello, connection_request=conn_request) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=delayed_conn_req) + # This gets pruned by csp + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=2), realtime=True) + assert len(msgs["recv"]) == 1 + # Only the second message is received + assert msgs["recv"][0][1].msg == "hi world1" + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + + def test_dynamic_multiple_subscribers(self): @csp.node - def send_msg_on_open(status: ts[Status]) -> ts[str]: + def send_on_status(status: ts[Status], uri: str, val: str) -> ts[str]: if csp.ticked(status): - return MsgStruct(a=1234, b="im a string").to_json() + if uri in status.msg and status.status_code == WebsocketStatus.ACTIVE.value: + return val + + with tornado_server(): + # We do this to only spawn the tornado server once for both options + @csp.graph + def g(use_on_connect_payload: bool): + ws = WebsocketAdapterManager(dynamic=True) + if use_on_connect_payload: + conn_request1 = csp.const( + [ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000")] + ) + conn_request2 = csp.const( + [ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001")] + ) + else: + conn_request1 = csp.const([ConnectionRequest(uri="ws://localhost:8000/")]) + conn_request2 = csp.const([ConnectionRequest(uri="ws://localhost:8001/")]) + status = ws.status() + to_send = send_on_status(status, "ws://localhost:8000/", "hey world from 8000") + to_send2 = send_on_status(status, "ws://localhost:8001/", "hey world from 8001") + ws.send(to_send, connection_request=conn_request1) + ws.send(to_send2, connection_request=conn_request2) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) + + csp.add_graph_output("recv", recv) + csp.add_graph_output("recv2", recv2) + + merged = csp.flatten([recv, recv2]) + stop = csp.filter(csp.count(merged) == 2, merged) + csp.stop_engine(stop) + + for use_on_connect_payload in [True, False]: + msgs = csp.run( + g, + use_on_connect_payload, + starttime=datetime.now(pytz.UTC), + endtime=timedelta(seconds=5), + realtime=True, + ) + assert len(msgs["recv"]) == 1 + assert msgs["recv"][0][1].msg == "hey world from 8000" + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert len(msgs["recv2"]) == 1 + assert msgs["recv2"][0][1].msg == "hey world from 8001" + assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" + + @pytest.mark.parametrize("dynamic", [False, True]) + def test_send_recv_burst_json(self, dynamic): + class MsgStruct(csp.Struct): + a: int + b: str @csp.node - def my_edge_that_handles_burst(objs: ts[List[MsgStruct]]) -> ts[bool]: + def my_edge_that_handles_burst(objs: ts[List[MsgStruct]]): if csp.ticked(objs): - return True + # Does nothing but makes sure it's not pruned + ... @csp.graph def g(): - ws = WebsocketAdapterManager("ws://localhost:8000/") - status = ws.status() - ws.send(send_msg_on_open(status)) - recv = ws.subscribe(MsgStruct, JSONTextMessageMapper(), push_mode=csp.PushMode.BURST) - _ = my_edge_that_handles_burst(recv) + if dynamic: + ws = WebsocketAdapterManager(dynamic=True) + wrapped_recv = ws.subscribe( + MsgStruct, + JSONTextMessageMapper(), + push_mode=csp.PushMode.BURST, + connection_request=csp.const( + [ + ConnectionRequest( + uri="ws://localhost:8000/", + on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), + ) + ] + ), + ) + recv = csp.apply(wrapped_recv, lambda vals: [v.msg for v in vals], List[MsgStruct]) + else: + ws = WebsocketAdapterManager("ws://localhost:8000/") + status = ws.status() + ws.send(csp.apply(status, lambda _x: MsgStruct(a=1234, b="im a string").to_json(), str)) + recv = ws.subscribe(MsgStruct, JSONTextMessageMapper(), push_mode=csp.PushMode.BURST) + + my_edge_that_handles_burst(recv) csp.add_graph_output("recv", recv) csp.stop_engine(recv) @@ -159,3 +541,28 @@ def g(): innerObj = obj[0] assert innerObj.a == 1234 assert innerObj.b == "im a string" + + def test_unkown_host_graceful_shutdown(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager("wss://localhost/") + # We need this since without any input or output + # adapters, the websocket connection is not actually made. + ws.send(csp.null_ts(str)) + assert ws._properties["port"] == "443" + csp.stop_engine(ws.status()) + + csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + + def test_unkown_host_graceful_shutdown_slow(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager("wss://localhost/") + # We need this since without any input or output + # adapters, the websocket connection is not actually made. + ws.send(csp.null_ts(str)) + assert ws._properties["port"] == "443" + stop_flag = csp.filter(csp.count(ws.status()) == 2, ws.status()) + csp.stop_engine(stop_flag) + + csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) From d70dce8fa84d766f5af8469701a6b3c7bd02da83 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Mon, 25 Nov 2024 05:06:30 -0500 Subject: [PATCH 2/8] Allow setting websockets to use binary Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- cpp/csp/adapters/websocket/WebsocketEndpoint.h | 4 ++++ csp/adapters/websocket.py | 6 +++++- csp/tests/adapters/test_websocket.py | 8 ++++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.h b/cpp/csp/adapters/websocket/WebsocketEndpoint.h index eea9fed2b..502d3eb59 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.h +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.h @@ -232,6 +232,8 @@ class WebsocketSessionNoTLS final: public WebsocketSession<WebsocketSessionNoTLS self->m_on_fail(ec.message()); return; } + if( self->m_properties->get<bool>("binary") ) + self->m_ws.binary( true ); self->m_on_open(); self->m_ws.async_read( self->m_buffer, @@ -339,6 +341,8 @@ class WebsocketSessionTLS final: public WebsocketSession<WebsocketSessionTLS> { self->m_on_fail(ec.message()); return; } + if( self->m_properties->get<bool>("binary") ) + self->m_ws.binary( true ); self->m_on_open(); self->m_ws.async_read( self->m_buffer, diff --git a/csp/adapters/websocket.py b/csp/adapters/websocket.py index 5fdeb049e..3a595aeef 100644 --- a/csp/adapters/websocket.py +++ b/csp/adapters/websocket.py @@ -423,6 +423,7 @@ def __init__( dynamic: bool = False, connection_request: Optional[ConnectionRequest] = None, num_threads: int = 1, + binary: bool = False, ): """ uri: str @@ -436,9 +437,11 @@ def __init__( num_threads: int = 1 Determines number of threads to allocate for running the websocket endpoints. Defaults to 1 to avoid thread switching + binary: bool = False + Whether to send/receive text or binary data """ - self._properties = dict(dynamic=dynamic, num_threads=num_threads) + self._properties = dict(dynamic=dynamic, num_threads=num_threads, binary=binary) # Enumerating for clarity if connection_request is not None and uri is not None: raise ValueError("'connection_request' cannot be set along with 'uri'") @@ -485,6 +488,7 @@ def _get_properties(self, conn_request: ConnectionRequest) -> dict: on_connect_payload=conn_request.on_connect_payload, uri=uri, dynamic=self._dynamic, + binary=self._properties.get("binary", False), ) return res diff --git a/csp/tests/adapters/test_websocket.py b/csp/tests/adapters/test_websocket.py index fa1463d70..6d23d5c2b 100644 --- a/csp/tests/adapters/test_websocket.py +++ b/csp/tests/adapters/test_websocket.py @@ -30,6 +30,9 @@ def on_message(self, msg): # Carve-out to allow inspecting the headers if msg == "header1": msg = self.request.headers.get(msg, "") + elif not isinstance(msg, str) and msg.decode("utf-8") == "header1": + # Need this for bytes + msg = self.request.headers.get("header1", "") return self.write_message(msg) @contextmanager @@ -98,11 +101,12 @@ def g(): msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) assert msgs["recv"][0][1] == "Hello, World!" - def test_headers(self): + @pytest.mark.parametrize("binary", [False, True]) + def test_headers(self, binary): @csp.graph def g(dynamic: bool): if dynamic: - ws = WebsocketAdapterManager(dynamic=True) + ws = WebsocketAdapterManager(dynamic=True, binary=binary) # Connect with header conn_request1 = csp.const( [ From 5664f567a70b991ab19a2a46cac3d9d80dac5f42 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Mon, 25 Nov 2024 16:18:31 -0500 Subject: [PATCH 3/8] Include websocket tests by default Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- Makefile | 4 +- cpp/csp/adapters/websocket/CMakeLists.txt | 2 - .../adapters/websocket/ClientAdapterManager.h | 1 - .../websocket/WebsocketClientTypes.cpp | 13 - .../adapters/websocket/WebsocketClientTypes.h | 26 -- .../websocket/WebsocketEndpointManager.cpp | 31 ++- .../websocket/WebsocketEndpointManager.h | 20 +- cpp/csp/engine/StatusAdapter.h | 2 +- cpp/csp/engine/Struct.h | 99 +++++-- csp/adapters/websocket.py | 9 +- csp/tests/adapters/test_websocket.py | 257 +++++++++--------- pyproject.toml | 1 + 12 files changed, 264 insertions(+), 201 deletions(-) delete mode 100644 cpp/csp/adapters/websocket/WebsocketClientTypes.cpp delete mode 100644 cpp/csp/adapters/websocket/WebsocketClientTypes.h diff --git a/Makefile b/Makefile index 32827d4eb..52bda34c3 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ build-debug: ## build the library ( DEBUG ) - May need a make clean when switch SKBUILD_CONFIGURE_OPTIONS="" DEBUG=1 python setup.py build build_ext --inplace build-conda: ## build the library in Conda - python setup.py build build_ext --csp-no-vcpkg --inplace + CSP_USE_CCACHE=0 python setup.py build build_ext --csp-no-vcpkg --inplace install: ## install library python -m pip install . @@ -83,7 +83,7 @@ checks: check TEST_ARGS := test-py: ## Clean and Make unit tests - python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS) + python -m pytest -vv -s csp/tests --junitxml=junit.xml $(TEST_ARGS) test-cpp: ## Make C++ unit tests ifneq ($(OS),Windows_NT) diff --git a/cpp/csp/adapters/websocket/CMakeLists.txt b/cpp/csp/adapters/websocket/CMakeLists.txt index bb85d20da..6bd7d7ae3 100644 --- a/cpp/csp/adapters/websocket/CMakeLists.txt +++ b/cpp/csp/adapters/websocket/CMakeLists.txt @@ -1,7 +1,6 @@ csp_autogen( csp.adapters.websocket_types websocket_types WEBSOCKET_HEADER WEBSOCKET_SOURCE ) set(WS_CLIENT_HEADER_FILES - WebsocketClientTypes.h ClientAdapterManager.h ClientInputAdapter.h ClientOutputAdapter.h @@ -13,7 +12,6 @@ set(WS_CLIENT_HEADER_FILES ) set(WS_CLIENT_SOURCE_FILES - WebsocketClientTypes.cpp ClientAdapterManager.cpp ClientInputAdapter.cpp ClientOutputAdapter.cpp diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.h b/cpp/csp/adapters/websocket/ClientAdapterManager.h index e530ee10f..101e05f98 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.h +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.h @@ -3,7 +3,6 @@ #include <csp/adapters/websocket/WebsocketEndpoint.h> #include <csp/adapters/websocket/WebsocketEndpointManager.h> -#include <csp/adapters/websocket/WebsocketClientTypes.h> #include <csp/adapters/websocket/ClientInputAdapter.h> #include <csp/adapters/websocket/ClientHeaderUpdateAdapter.h> #include <csp/core/Enum.h> diff --git a/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp b/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp deleted file mode 100644 index ac4492520..000000000 --- a/cpp/csp/adapters/websocket/WebsocketClientTypes.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "WebsocketClientTypes.h" - -namespace csp { - -INIT_CSP_ENUM( adapters::websocket::ClientStatusType, - "ACTIVE", - "GENERIC_ERROR", - "CONNECTION_FAILED", - "CLOSED", - "MESSAGE_SEND_FAIL", -); - -} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketClientTypes.h b/cpp/csp/adapters/websocket/WebsocketClientTypes.h deleted file mode 100644 index 6f1d255f8..000000000 --- a/cpp/csp/adapters/websocket/WebsocketClientTypes.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include "csp/core/Enum.h" // or whatever the correct path is - -namespace csp::adapters::websocket { - -struct WebsocketClientStatusTypeTraits -{ - enum _enum : unsigned char - { - ACTIVE = 0, - GENERIC_ERROR = 1, - CONNECTION_FAILED = 2, - CLOSED = 3, - MESSAGE_SEND_FAIL = 4, - - NUM_TYPES - }; - -protected: - _enum m_value; -}; - -using ClientStatusType = Enum<WebsocketClientStatusTypeTraits>; - -} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp index b960af825..6df8c8732 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -1,5 +1,16 @@ #include <csp/adapters/websocket/WebsocketEndpointManager.h> +namespace csp { + +INIT_CSP_ENUM( adapters::websocket::ClientStatusType, + "ACTIVE", + "GENERIC_ERROR", + "CONNECTION_FAILED", + "CLOSED", + "MESSAGE_SEND_FAIL", +); + +} namespace csp::adapters::websocket { WebsocketEndpointManager::WebsocketEndpointManager( ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine ) @@ -96,7 +107,8 @@ void WebsocketEndpointManager::shutdownEndpoint(const std::string& endpoint_id) m_endpoints.erase(endpoint_it); std::stringstream ss; ss << "No more connections for endpoint={" << endpoint_id << "} Shutting down..."; - m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, ss.str()); + std::string msg = ss.str(); + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, msg); } void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, @@ -135,9 +147,10 @@ void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, // should only happen if persist is False if ( !payload.empty() ) endpoint -> send(payload); - - m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE, - "Connected successfully for endpoint={" + endpoint_id +"}"); + std::stringstream ss; + ss << "Connected successfully for endpoint={" << endpoint_id << "}"; + std::string msg = ss.str(); + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE, msg); // We remove the caller id, if it was the only one, then we shut down the endpoint if( !persist ) removeEndpointForCallerId(endpoint_id, is_consumer, validated_id); @@ -170,8 +183,9 @@ void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, stored_endpoint -> setOnSendFail( [ this, endpoint_id ]( const std::string& s ) { std::stringstream ss; - ss << "Error: " << s << " for " << endpoint_id; - m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, ss.str() ); + ss << "Error: " << s << " for endpoint={" << endpoint_id << "}"; + std::string msg = ss.str(); + m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, msg ); } ); stored_endpoint -> run(); @@ -214,10 +228,11 @@ void WebsocketEndpointManager::handleEndpointFailure(const std::string& endpoint std::stringstream ss; ss << "Connection Failure for endpoint={" << endpoint_id << "} Due to: " << reason; + std::string msg = ss.str(); if ( status_type == ClientStatusType::CLOSED || status_type == ClientStatusType::ACTIVE ) - m_mgr -> pushStatus(StatusLevel::INFO, status_type, ss.str()); + m_mgr -> pushStatus(StatusLevel::INFO, status_type, msg); else{ - m_mgr -> pushStatus(StatusLevel::ERROR, status_type, ss.str()); + m_mgr -> pushStatus(StatusLevel::ERROR, status_type, msg); } }; diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.h b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h index 736413e03..9007251ed 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpointManager.h +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h @@ -2,7 +2,6 @@ #define WEBSOCKET_ENDPOINT_MANAGER_H #include <boost/asio.hpp> -#include <csp/adapters/websocket/WebsocketClientTypes.h> #include <csp/adapters/websocket/WebsocketEndpoint.h> #include <csp/adapters/websocket/ClientAdapterManager.h> #include <csp/adapters/websocket/ClientInputAdapter.h> @@ -70,6 +69,25 @@ struct EndpointCallbacks { std::function<void(const std::string&, void*, size_t)> onMessage; }; +struct WebsocketClientStatusTypeTraits +{ + enum _enum : unsigned char + { + ACTIVE = 0, + GENERIC_ERROR = 1, + CONNECTION_FAILED = 2, + CLOSED = 3, + MESSAGE_SEND_FAIL = 4, + + NUM_TYPES + }; + +protected: + _enum m_value; +}; + +using ClientStatusType = Enum<WebsocketClientStatusTypeTraits>; + class WebsocketEndpointManager { public: explicit WebsocketEndpointManager(ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine); diff --git a/cpp/csp/engine/StatusAdapter.h b/cpp/csp/engine/StatusAdapter.h index f243c7c70..8545b7619 100644 --- a/cpp/csp/engine/StatusAdapter.h +++ b/cpp/csp/engine/StatusAdapter.h @@ -44,7 +44,7 @@ class StatusAdapter : public PushInputAdapter m_statusAccess.meta = meta; m_statusAccess.level = meta -> getMetaField<int64_t>( "level", "Status" ); m_statusAccess.statusCode = meta -> getMetaField<int64_t>( "status_code", "Status" ); - m_statusAccess.msg = meta -> getMetaField<std::string>( "msg", "Status" ); + m_statusAccess.msg = meta -> getMetaField<typename csp::StringStructField::CType>( "msg", "Status" ); } void pushStatus( int64_t level, int64_t statusCode, const std::string & msg, PushBatch *batch = nullptr ) diff --git a/cpp/csp/engine/Struct.h b/cpp/csp/engine/Struct.h index e0653e4a3..78d6a9265 100644 --- a/cpp/csp/engine/Struct.h +++ b/cpp/csp/engine/Struct.h @@ -7,6 +7,7 @@ #include <string> #include <vector> #include <unordered_map> +#include <iostream> namespace csp { @@ -668,16 +669,69 @@ class StructMeta : public std::enable_shared_from_this<StructMeta> }; template<typename T> -std::shared_ptr<typename StructField::upcast<T>::type> StructMeta::getMetaField( const char * fieldname, const char * expectedtype ) -{ - auto field_ = field( fieldname ); - if( !field_ ) - CSP_THROW( TypeError, "Struct type " << name() << " missing required field " << fieldname << " for " << expectedtype ); +std::shared_ptr<typename StructField::upcast<T>::type> StructMeta::getMetaField(const char* fieldname, const char* expectedtype) { + std::cout << "\n=== getMetaField Debug ===\n"; + std::cout << "1. Looking for field: " << fieldname << "\n"; + + auto field_ = field(fieldname); + if(!field_) { + std::cout << "2. Field not found!\n"; + CSP_THROW(TypeError, "Struct type " << name() << " missing required field " << fieldname); + } + + std::cout << "2. Field found\n"; + std::cout << "3. Field name from object: " << field_->fieldname() << "\n"; + std::cout << "4. Field type from CspType: " << field_->type()->type() << "\n"; + std::cout << "5. Expected type: " << CspType::Type::fromCType<T>::type << "\n"; + + // Memory layout & pointer checks + const StructField* field_ptr = field_.get(); + std::cout << "6. Field ptr value: " << field_ptr << "\n"; + std::cout << "7. Field use count: " << field_.use_count() << "\n"; + + // Detailed field information + if(field_ptr) { + std::cout << "8. Field metadata:\n"; + std::cout << " - Field offset: " << field_ptr->offset() << "\n"; + std::cout << " - Field size: " << field_ptr->size() << "\n"; + std::cout << " - Field alignment: " << field_ptr->alignment() << "\n"; + std::cout << " - Field mask offset: " << field_ptr->maskOffset() << "\n"; + std::cout << " - Field mask bit: " << static_cast<int>(field_ptr->maskBit()) << "\n"; + + // Type verification + std::cout << "9. Type checks:\n"; + std::cout << " - Original type: " << typeid(field_).name() << "\n"; + std::cout << " - Target type: " << typeid(typename StructField::upcast<T>::type).name() << "\n"; + std::cout << " - Is native: " << field_ptr->isNative() << "\n"; + + // Test various casts + std::cout << "10. Detailed cast tests:\n"; + std::cout << " Base classes:\n"; + std::cout << " - As StructField*: " << (dynamic_cast<const StructField*>(field_ptr) != nullptr) << "\n"; + std::cout << " - As NonNativeStructField*: " << (dynamic_cast<const NonNativeStructField*>(field_ptr) != nullptr) << "\n"; + std::cout << " Non-native implementations:\n"; + std::cout << " - As StringStructField*: " << (dynamic_cast<const StringStructField*>(field_ptr) != nullptr) << "\n"; + std::cout << " - As DialectGenericStructField*: " << (dynamic_cast<const DialectGenericStructField*>(field_ptr) != nullptr) << "\n"; + std::cout << " - As ArrayStructField<std::string>*: " << (dynamic_cast<const ArrayStructField<std::vector<std::string>>*>(field_ptr) != nullptr) << "\n"; + std::cout << " Native field test:\n"; + std::cout << " - As NativeStructField<int64_t>*: " << (dynamic_cast<const NativeStructField<int64_t>*>(field_ptr) != nullptr) << "\n"; +} + + using TargetType = typename StructField::upcast<T>::type; + auto typedfield = std::dynamic_pointer_cast<TargetType>(field_); + std::cout << "11. Final dynamic_cast result: " << (typedfield ? "success" : "failure") << "\n"; - std::shared_ptr<typename StructField::upcast<T>::type> typedfield = std::dynamic_pointer_cast<typename StructField::upcast<T>::type>( field_ ); - if( !typedfield ) - CSP_THROW( TypeError, expectedtype << " - provided struct type " << name() << " expected type " << CspType::Type::fromCType<T>::type << " for field " << fieldname - << " but got type " << field_ -> type() -> type() << " for " << expectedtype ); + if(!typedfield) { + std::cout << "12. FAILED CAST DETAILS:\n"; + std::cout << " - Source type: " << typeid(StructField).name() << "\n"; + std::cout << " - Target type: " << typeid(TargetType).name() << "\n"; + + CSP_THROW(TypeError, expectedtype << " - provided struct type " << name() + << " expected type " << CspType::Type::fromCType<T>::type + << " for field " << fieldname + << " but got type " << field_->type()->type() + << " for " << expectedtype); + } return typedfield; } @@ -773,11 +827,11 @@ class Struct friend class StructMeta; //Note these members are not included on size(), they're stored before "this" ptr ( see operator new / delete ) - struct HiddenData - { - size_t refcount; - std::shared_ptr<const StructMeta> meta; - void * dialectPtr; + struct alignas(8) HiddenData { + alignas(8) size_t refcount; // 8 bytes at 0x0 + alignas(8) std::shared_ptr<const StructMeta> meta; // 16 bytes at 0x8 + alignas(8) void* dialectPtr; // 8 bytes at 0x18 + // Total: 32 bytes }; const HiddenData * hidden() const @@ -785,10 +839,19 @@ class Struct return const_cast<Struct *>( this ) -> hidden(); } - HiddenData * hidden() - { - return reinterpret_cast<HiddenData *>( reinterpret_cast<uint8_t *>( this ) - sizeof( HiddenData ) ); - } + static constexpr size_t HIDDEN_OFFSET = 32; // sizeof(HiddenData) aligned to 8 bytes + + HiddenData* hidden() { + std::byte* base = reinterpret_cast<std::byte*>(this); + // Force alignment to match shared_ptr requirements + static_assert(alignof(HiddenData) >= alignof(std::shared_ptr<void>), + "HiddenData must be aligned for shared_ptr"); + return reinterpret_cast<HiddenData*>(base - HIDDEN_OFFSET); + } + // HiddenData * hidden() + // { + // return reinterpret_cast<HiddenData *>( reinterpret_cast<uint8_t *>( this ) - sizeof( HiddenData ) ); + // } //actual data is allocated past this point }; diff --git a/csp/adapters/websocket.py b/csp/adapters/websocket.py index 3a595aeef..d1a5c1680 100644 --- a/csp/adapters/websocket.py +++ b/csp/adapters/websocket.py @@ -8,6 +8,7 @@ import csp from csp import ts +from csp.adapters.dynamic_adapter_utils import AdapterInfo from csp.adapters.status import Status from csp.adapters.utils import ( BytesMessageProtoMapper, @@ -17,20 +18,20 @@ RawBytesMessageMapper, RawTextMessageMapper, ) +from csp.adapters.websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus from csp.impl.wiring import input_adapter_def, output_adapter_def, status_adapter_def from csp.impl.wiring.delayed_node import DelayedNodeWrapperDef from csp.lib import _websocketadapterimpl -from .dynamic_adapter_utils import AdapterInfo -from .websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus # noqa - # InternalConnectionRequest, _ = ( + ActionType, BytesMessageProtoMapper, DateTimeType, JSONTextMessageMapper, RawBytesMessageMapper, RawTextMessageMapper, + WebsocketStatus, ) T = TypeVar("T") @@ -577,7 +578,7 @@ def update_headers(self, x: ts[List[WebsocketHeaderUpdate]]): def status(self, push_mode=csp.PushMode.NON_COLLAPSING): ts_type = Status - return status_adapter_def(self, ts_type, push_mode=push_mode) + return status_adapter_def(self, ts_type, push_mode) def _create(self, engine, memo): """method needs to return the wrapped c++ adapter manager""" diff --git a/csp/tests/adapters/test_websocket.py b/csp/tests/adapters/test_websocket.py index 6d23d5c2b..0b37376e6 100644 --- a/csp/tests/adapters/test_websocket.py +++ b/csp/tests/adapters/test_websocket.py @@ -2,84 +2,101 @@ import pytest import pytz import threading +import tornado.ioloop +import tornado.web +import tornado.websocket from contextlib import contextmanager from datetime import datetime, timedelta +from tornado.testing import bind_unused_port from typing import List, Optional, Type import csp from csp import ts - -if os.environ.get("CSP_TEST_WEBSOCKET"): - import tornado.ioloop - import tornado.web - import tornado.websocket - - from csp.adapters.websocket import ( - ActionType, - ConnectionRequest, - JSONTextMessageMapper, - RawTextMessageMapper, - Status, - WebsocketAdapterManager, - WebsocketHeaderUpdate, - WebsocketStatus, - ) - - class EchoWebsocketHandler(tornado.websocket.WebSocketHandler): - def on_message(self, msg): - # Carve-out to allow inspecting the headers - if msg == "header1": - msg = self.request.headers.get(msg, "") - elif not isinstance(msg, str) and msg.decode("utf-8") == "header1": - # Need this for bytes - msg = self.request.headers.get("header1", "") - return self.write_message(msg) - - @contextmanager - def create_tornado_server(port: int): - """Base context manager for creating a Tornado server in a thread""" - ready_event = threading.Event() - io_loop = None - app = None - io_thread = None - - def run_io_loop(): - nonlocal io_loop, app - io_loop = tornado.ioloop.IOLoop() - io_loop.make_current() - app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) +from csp.adapters.websocket import ( + ActionType, + ConnectionRequest, + JSONTextMessageMapper, + RawTextMessageMapper, + Status, + WebsocketAdapterManager, + WebsocketHeaderUpdate, + WebsocketStatus, +) + + +class EchoWebsocketHandler(tornado.websocket.WebSocketHandler): + def on_message(self, msg): + # Carve-out to allow inspecting the headers + if msg == "header1": + msg = self.request.headers.get(msg, "") + elif not isinstance(msg, str) and msg.decode("utf-8") == "header1": + # Need this for bytes + msg = self.request.headers.get("header1", "") + return self.write_message(msg) + + +@contextmanager +def create_tornado_server(port: int = None): + """Base context manager for creating a Tornado server in a thread + + Args: + port: Optional port number. If None, an unused port will be chosen. + + Returns: + Tuple containing (io_loop, app, io_thread, port) + """ + ready_event = threading.Event() + io_loop = None + app = None + io_thread = None + + # Get an unused port if none specified + if port is None: + sock, port = bind_unused_port() + sock.close() + + def run_io_loop(): + nonlocal io_loop, app + io_loop = tornado.ioloop.IOLoop() + io_loop.make_current() + app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) + try: app.listen(port) ready_event.set() io_loop.start() + except Exception as e: + ready_event.set() # Ensure we don't hang in case of error + raise - io_thread = threading.Thread(target=run_io_loop) - io_thread.start() - ready_event.wait() + io_thread = threading.Thread(target=run_io_loop) + io_thread.start() + ready_event.wait() + + try: + yield io_loop, app, io_thread, port + finally: + io_loop.add_callback(io_loop.stop) + if io_thread: + io_thread.join(timeout=5) + if io_thread.is_alive(): + raise RuntimeError("IOLoop failed to stop") - try: - yield io_loop, app, io_thread - finally: - io_loop.add_callback(io_loop.stop) - if io_thread: - io_thread.join(timeout=5) - if io_thread.is_alive(): - raise RuntimeError("IOLoop failed to stop") - - @contextmanager - def tornado_server(port: int = 8001): - """Simplified context manager that uses the base implementation""" - with create_tornado_server(port) as (_io_loop, _app, _io_thread): - yield + +@contextmanager +def tornado_server(): + """Simplified context manager that uses the base implementation with dynamic port""" + with create_tornado_server() as (_io_loop, _app, _io_thread, port): + yield port -@pytest.mark.skipif(os.environ.get("CSP_TEST_WEBSOCKET") is None, reason="'CSP_TEST_WEBSOCKET' env variable is not set") class TestWebsocket: @pytest.fixture(scope="class", autouse=True) def setup_tornado(self, request): - with create_tornado_server(8000) as (io_loop, app, io_thread): + with create_tornado_server() as (io_loop, app, io_thread, port): request.cls.io_loop = io_loop request.cls.app = app request.cls.io_thread = io_thread + request.cls.port = port # Make the port available to tests yield def test_send_recv_msg(self): @@ -90,7 +107,7 @@ def send_msg_on_open(status: ts[Status]) -> ts[str]: @csp.graph def g(): - ws = WebsocketAdapterManager("ws://localhost:8000/") + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") status = ws.status() ws.send(send_msg_on_open(status)) recv = ws.subscribe(str, RawTextMessageMapper()) @@ -107,41 +124,44 @@ def test_headers(self, binary): def g(dynamic: bool): if dynamic: ws = WebsocketAdapterManager(dynamic=True, binary=binary) - # Connect with header conn_request1 = csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value1"} + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value1"}, ) ] ) - # Disconnect to shutdown endpoint conn_request2 = csp.const( - [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT)], + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT)], delay=timedelta(milliseconds=100), ) - # Reconnect to open endpoint with new headers conn_request3 = csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value2"} + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value2"}, ) ], delay=timedelta(milliseconds=150), ) - conn_request3 = csp.const( - [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.PING)], + conn_request4 = csp.const( + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.PING)], delay=timedelta(milliseconds=151), ) - conn_request4 = csp.const( + conn_request5 = csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", on_connect_payload="header1", headers={"header1": "value2"} + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value2"}, ) ], delay=timedelta(milliseconds=200), ) - conn_req = csp.flatten([conn_request1, conn_request2, conn_request3, conn_request4]) + conn_req = csp.flatten([conn_request1, conn_request2, conn_request3, conn_request4, conn_request5]) status = ws.status() csp.add_graph_output("status", status) recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_req) @@ -150,7 +170,7 @@ def g(dynamic: bool): csp.stop_engine(stop) if not dynamic: - ws = WebsocketAdapterManager("ws://localhost:8000/", headers={"header1": "value1"}) + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/", headers={"header1": "value1"}) status = ws.status() send_msg = csp.sample(status, csp.const("header1")) to_send = csp.merge(send_msg, csp.const("header1", delay=timedelta(milliseconds=100))) @@ -160,7 +180,6 @@ def g(dynamic: bool): header_update = csp.const( [WebsocketHeaderUpdate(key="header1", value="value2")], delay=timedelta(milliseconds=50) ) - # Doesn' tick out since we don't disconnect ws.update_headers(header_update) status = ws.status() csp.add_graph_output("status", status) @@ -168,22 +187,6 @@ def g(dynamic: bool): csp.add_graph_output("recv", recv) csp.stop_engine(recv) - msgs = csp.run(g, dynamic=False, starttime=datetime.now(pytz.UTC), realtime=True) - assert msgs["recv"][0][1] == "value1" - assert len(msgs["status"]) == 1 - assert msgs["status"][0][1].status_code == WebsocketStatus.ACTIVE.value - - msgs = csp.run(g, dynamic=True, starttime=datetime.now(pytz.UTC), realtime=True) - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" - assert msgs["recv"][1][1].uri == "ws://localhost:8000/" - assert msgs["recv"][0][1].msg == "value1" - assert msgs["recv"][1][1].msg == "value2" - - assert len(msgs["status"]) == 3 - assert msgs["status"][0][1].status_code == WebsocketStatus.ACTIVE.value - assert msgs["status"][1][1].status_code == WebsocketStatus.CLOSED.value - assert msgs["status"][2][1].status_code == WebsocketStatus.ACTIVE.value - @pytest.mark.parametrize("send_payload_subscribe", [True, False]) def test_send_recv_json_dynamic_on_connect_payload(self, send_payload_subscribe): class MsgStruct(csp.Struct): @@ -194,7 +197,7 @@ class MsgStruct(csp.Struct): def g(): ws = WebsocketAdapterManager(dynamic=True) conn_request = ConnectionRequest( - uri="ws://localhost:8000/", + uri=f"ws://localhost:{self.port}/", action=ActionType.CONNECT, on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), ) @@ -203,7 +206,7 @@ def g(): # The 'on_connect_payload sends the result ws.send(csp.null_ts(object), connection_request=csp.const([conn_request])) subscribe_connection_request = ( - [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.CONNECT)] + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.CONNECT)] if not send_payload_subscribe else [conn_request] ) @@ -216,7 +219,7 @@ def g(): msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) obj = msgs["recv"][0][1] - assert obj.uri == "ws://localhost:8000/" + assert obj.uri == f"ws://localhost:{self.port}/" true_obj = obj.msg assert isinstance(true_obj, MsgStruct) assert true_obj.a == 1234 @@ -234,7 +237,7 @@ def send_msg_on_open(status: ts[Status]) -> ts[str]: @csp.graph def g(): - ws = WebsocketAdapterManager("ws://localhost:8000/") + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") status = ws.status() ws.send(send_msg_on_open(status)) recv = ws.subscribe(MsgStruct, JSONTextMessageMapper()) @@ -273,7 +276,7 @@ def stop_on_all_or_timeout(msgs: ts[str], l: int = 50) -> ts[bool]: @csp.graph def g(n: int): - ws = WebsocketAdapterManager("ws://localhost:8000/") + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") status = ws.status() ws.send(csp.flatten([send_msg_on_open(status, i) for i in range(n)])) recv = ws.subscribe(str, RawTextMessageMapper()) @@ -293,7 +296,7 @@ def g(): conn_request = csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", + uri=f"ws://localhost:{self.port}/", action=ActionType.CONNECT, ) ] @@ -337,15 +340,15 @@ def g(): if reconnect_immeditately: disconnect_reqs = [ - ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT), - ConnectionRequest(uri="ws://localhost:8000/"), + ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT), + ConnectionRequest(uri=f"ws://localhost:{self.port}/"), ] else: - disconnect_reqs = [ConnectionRequest(uri="ws://localhost:8000/", action=ActionType.DISCONNECT)] + disconnect_reqs = [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT)] conn_request = csp.curve( List[ConnectionRequest], [ - (timedelta(), [ConnectionRequest(uri="ws://localhost:8000/")]), + (timedelta(), [ConnectionRequest(uri=f"ws://localhost:{self.port}/")]), ( timedelta(milliseconds=100), disconnect_reqs, @@ -354,14 +357,14 @@ def g(): timedelta(milliseconds=350), [ ConnectionRequest( - uri="ws://localhost:8000/", + uri=f"ws://localhost:{self.port}/", headers={"dummy_key": "dummy_value"}, ), ], ), ], ) - const_conn_request = csp.const([ConnectionRequest(uri="ws://localhost:8000/")]) + const_conn_request = csp.const([ConnectionRequest(uri=f"ws://localhost:{self.port}/")]) val = csp.curve(int, [(timedelta(milliseconds=100, microseconds=1), 0), (timedelta(milliseconds=500), 1)]) hello = csp.apply(val, lambda x: f"hi world{x}", str) @@ -374,7 +377,7 @@ def g(): recv3 = ws.subscribe(str, RawTextMessageMapper(), connection_request=const_conn_request) no_persist_conn = ConnectionRequest( - uri="ws://localhost:8000/", persistent=False, on_connect_payload="hi non-persistent world!" + uri=f"ws://localhost:{self.port}/", persistent=False, on_connect_payload="hi non-persistent world!" ) recv4 = ws.subscribe( str, @@ -395,24 +398,24 @@ def g(): if not reconnect_immeditately: assert len(msgs["recv"]) == 1 assert msgs["recv"][0][1].msg == "hi world1" - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" else: assert len(msgs["recv"]) == 3 assert msgs["recv"][0][1].msg == "hi world0" - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" assert msgs["recv"][1][1].msg == "hi non-persistent world!" - assert msgs["recv"][1][1].uri == "ws://localhost:8000/" + assert msgs["recv"][1][1].uri == f"ws://localhost:{self.port}/" assert msgs["recv"][2][1].msg == "hi world1" - assert msgs["recv"][2][1].uri == "ws://localhost:8000/" + assert msgs["recv"][2][1].uri == f"ws://localhost:{self.port}/" # This subscribe call received all the messages assert len(msgs["recv3"]) == 3 assert msgs["recv3"][0][1].msg == "hi world0" - assert msgs["recv3"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv3"][0][1].uri == f"ws://localhost:{self.port}/" assert msgs["recv3"][1][1].msg == "hi non-persistent world!" - assert msgs["recv3"][1][1].uri == "ws://localhost:8000/" + assert msgs["recv3"][1][1].uri == f"ws://localhost:{self.port}/" assert msgs["recv3"][2][1].msg == "hi world1" - assert msgs["recv3"][2][1].uri == "ws://localhost:8000/" + assert msgs["recv3"][2][1].uri == f"ws://localhost:{self.port}/" def test_dynamic_pruned_subscribe(self): @csp.graph @@ -421,7 +424,7 @@ def g(): conn_request = csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", + uri=f"ws://localhost:{self.port}/", action=ActionType.CONNECT, ) ] @@ -444,7 +447,7 @@ def g(): assert len(msgs["recv"]) == 1 # Only the second message is received assert msgs["recv"][0][1].msg == "hi world1" - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" def test_dynamic_multiple_subscribers(self): @csp.node @@ -453,24 +456,28 @@ def send_on_status(status: ts[Status], uri: str, val: str) -> ts[str]: if uri in status.msg and status.status_code == WebsocketStatus.ACTIVE.value: return val - with tornado_server(): - # We do this to only spawn the tornado server once for both options + with tornado_server() as port2: # Get a second dynamic port + @csp.graph def g(use_on_connect_payload: bool): ws = WebsocketAdapterManager(dynamic=True) if use_on_connect_payload: conn_request1 = csp.const( - [ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000")] + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", on_connect_payload="hey world from main" + ) + ] ) conn_request2 = csp.const( - [ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001")] + [ConnectionRequest(uri=f"ws://localhost:{port2}/", on_connect_payload="hey world from second")] ) else: - conn_request1 = csp.const([ConnectionRequest(uri="ws://localhost:8000/")]) - conn_request2 = csp.const([ConnectionRequest(uri="ws://localhost:8001/")]) + conn_request1 = csp.const([ConnectionRequest(uri=f"ws://localhost:{self.port}/")]) + conn_request2 = csp.const([ConnectionRequest(uri=f"ws://localhost:{port2}/")]) status = ws.status() - to_send = send_on_status(status, "ws://localhost:8000/", "hey world from 8000") - to_send2 = send_on_status(status, "ws://localhost:8001/", "hey world from 8001") + to_send = send_on_status(status, f"ws://localhost:{self.port}/", "hey world from main") + to_send2 = send_on_status(status, f"ws://localhost:{port2}/", "hey world from second") ws.send(to_send, connection_request=conn_request1) ws.send(to_send2, connection_request=conn_request2) @@ -493,11 +500,11 @@ def g(use_on_connect_payload: bool): realtime=True, ) assert len(msgs["recv"]) == 1 - assert msgs["recv"][0][1].msg == "hey world from 8000" - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" + assert msgs["recv"][0][1].msg == "hey world from main" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" assert len(msgs["recv2"]) == 1 - assert msgs["recv2"][0][1].msg == "hey world from 8001" - assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" + assert msgs["recv2"][0][1].msg == "hey world from second" + assert msgs["recv2"][0][1].uri == f"ws://localhost:{port2}/" @pytest.mark.parametrize("dynamic", [False, True]) def test_send_recv_burst_json(self, dynamic): @@ -522,7 +529,7 @@ def g(): connection_request=csp.const( [ ConnectionRequest( - uri="ws://localhost:8000/", + uri=f"ws://localhost:{self.port}/", on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), ) ] @@ -530,7 +537,7 @@ def g(): ) recv = csp.apply(wrapped_recv, lambda vals: [v.msg for v in vals], List[MsgStruct]) else: - ws = WebsocketAdapterManager("ws://localhost:8000/") + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") status = ws.status() ws.send(csp.apply(status, lambda _x: MsgStruct(a=1234, b="im a string").to_json(), str)) recv = ws.subscribe(MsgStruct, JSONTextMessageMapper(), push_mode=csp.PushMode.BURST) diff --git a/pyproject.toml b/pyproject.toml index e790a1b21..07acc860d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ develop = [ "sqlalchemy", # db "threadpoolctl", # test_random "tornado", # profiler, perspective, websocket + "python-rapidjson", # websocket # type checking "pydantic>=2", ] From 628ceb69c84b996ed95300296e80543d18c25237 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Thu, 28 Nov 2024 21:32:26 -0500 Subject: [PATCH 4/8] Disable dynamic_cast optimization introduced in clang17 Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index f08712a63..6f9d2d000 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,6 +204,11 @@ else() -Wno-maybe-uninitialized \ ") endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 17) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-assume-unique-vtables") + endif() + endif() endif() endif() From 639c5aa03476ca2ea8ac2a5f854cc53fad41ac5c Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Thu, 28 Nov 2024 21:46:56 -0500 Subject: [PATCH 5/8] Remove comments, more delay for test Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- Makefile | 4 +- cpp/csp/engine/StatusAdapter.h | 2 +- cpp/csp/engine/Struct.h | 99 ++++--------------- .../python/adapters/websocketadapterimpl.cpp | 16 --- csp/tests/adapters/test_websocket.py | 23 ++--- 5 files changed, 29 insertions(+), 115 deletions(-) diff --git a/Makefile b/Makefile index 52bda34c3..32827d4eb 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ build-debug: ## build the library ( DEBUG ) - May need a make clean when switch SKBUILD_CONFIGURE_OPTIONS="" DEBUG=1 python setup.py build build_ext --inplace build-conda: ## build the library in Conda - CSP_USE_CCACHE=0 python setup.py build build_ext --csp-no-vcpkg --inplace + python setup.py build build_ext --csp-no-vcpkg --inplace install: ## install library python -m pip install . @@ -83,7 +83,7 @@ checks: check TEST_ARGS := test-py: ## Clean and Make unit tests - python -m pytest -vv -s csp/tests --junitxml=junit.xml $(TEST_ARGS) + python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS) test-cpp: ## Make C++ unit tests ifneq ($(OS),Windows_NT) diff --git a/cpp/csp/engine/StatusAdapter.h b/cpp/csp/engine/StatusAdapter.h index 8545b7619..f243c7c70 100644 --- a/cpp/csp/engine/StatusAdapter.h +++ b/cpp/csp/engine/StatusAdapter.h @@ -44,7 +44,7 @@ class StatusAdapter : public PushInputAdapter m_statusAccess.meta = meta; m_statusAccess.level = meta -> getMetaField<int64_t>( "level", "Status" ); m_statusAccess.statusCode = meta -> getMetaField<int64_t>( "status_code", "Status" ); - m_statusAccess.msg = meta -> getMetaField<typename csp::StringStructField::CType>( "msg", "Status" ); + m_statusAccess.msg = meta -> getMetaField<std::string>( "msg", "Status" ); } void pushStatus( int64_t level, int64_t statusCode, const std::string & msg, PushBatch *batch = nullptr ) diff --git a/cpp/csp/engine/Struct.h b/cpp/csp/engine/Struct.h index 78d6a9265..e0653e4a3 100644 --- a/cpp/csp/engine/Struct.h +++ b/cpp/csp/engine/Struct.h @@ -7,7 +7,6 @@ #include <string> #include <vector> #include <unordered_map> -#include <iostream> namespace csp { @@ -669,69 +668,16 @@ class StructMeta : public std::enable_shared_from_this<StructMeta> }; template<typename T> -std::shared_ptr<typename StructField::upcast<T>::type> StructMeta::getMetaField(const char* fieldname, const char* expectedtype) { - std::cout << "\n=== getMetaField Debug ===\n"; - std::cout << "1. Looking for field: " << fieldname << "\n"; - - auto field_ = field(fieldname); - if(!field_) { - std::cout << "2. Field not found!\n"; - CSP_THROW(TypeError, "Struct type " << name() << " missing required field " << fieldname); - } - - std::cout << "2. Field found\n"; - std::cout << "3. Field name from object: " << field_->fieldname() << "\n"; - std::cout << "4. Field type from CspType: " << field_->type()->type() << "\n"; - std::cout << "5. Expected type: " << CspType::Type::fromCType<T>::type << "\n"; - - // Memory layout & pointer checks - const StructField* field_ptr = field_.get(); - std::cout << "6. Field ptr value: " << field_ptr << "\n"; - std::cout << "7. Field use count: " << field_.use_count() << "\n"; - - // Detailed field information - if(field_ptr) { - std::cout << "8. Field metadata:\n"; - std::cout << " - Field offset: " << field_ptr->offset() << "\n"; - std::cout << " - Field size: " << field_ptr->size() << "\n"; - std::cout << " - Field alignment: " << field_ptr->alignment() << "\n"; - std::cout << " - Field mask offset: " << field_ptr->maskOffset() << "\n"; - std::cout << " - Field mask bit: " << static_cast<int>(field_ptr->maskBit()) << "\n"; - - // Type verification - std::cout << "9. Type checks:\n"; - std::cout << " - Original type: " << typeid(field_).name() << "\n"; - std::cout << " - Target type: " << typeid(typename StructField::upcast<T>::type).name() << "\n"; - std::cout << " - Is native: " << field_ptr->isNative() << "\n"; - - // Test various casts - std::cout << "10. Detailed cast tests:\n"; - std::cout << " Base classes:\n"; - std::cout << " - As StructField*: " << (dynamic_cast<const StructField*>(field_ptr) != nullptr) << "\n"; - std::cout << " - As NonNativeStructField*: " << (dynamic_cast<const NonNativeStructField*>(field_ptr) != nullptr) << "\n"; - std::cout << " Non-native implementations:\n"; - std::cout << " - As StringStructField*: " << (dynamic_cast<const StringStructField*>(field_ptr) != nullptr) << "\n"; - std::cout << " - As DialectGenericStructField*: " << (dynamic_cast<const DialectGenericStructField*>(field_ptr) != nullptr) << "\n"; - std::cout << " - As ArrayStructField<std::string>*: " << (dynamic_cast<const ArrayStructField<std::vector<std::string>>*>(field_ptr) != nullptr) << "\n"; - std::cout << " Native field test:\n"; - std::cout << " - As NativeStructField<int64_t>*: " << (dynamic_cast<const NativeStructField<int64_t>*>(field_ptr) != nullptr) << "\n"; -} - - using TargetType = typename StructField::upcast<T>::type; - auto typedfield = std::dynamic_pointer_cast<TargetType>(field_); - std::cout << "11. Final dynamic_cast result: " << (typedfield ? "success" : "failure") << "\n"; +std::shared_ptr<typename StructField::upcast<T>::type> StructMeta::getMetaField( const char * fieldname, const char * expectedtype ) +{ + auto field_ = field( fieldname ); + if( !field_ ) + CSP_THROW( TypeError, "Struct type " << name() << " missing required field " << fieldname << " for " << expectedtype ); - if(!typedfield) { - std::cout << "12. FAILED CAST DETAILS:\n"; - std::cout << " - Source type: " << typeid(StructField).name() << "\n"; - std::cout << " - Target type: " << typeid(TargetType).name() << "\n"; - - CSP_THROW(TypeError, expectedtype << " - provided struct type " << name() - << " expected type " << CspType::Type::fromCType<T>::type - << " for field " << fieldname - << " but got type " << field_->type()->type() - << " for " << expectedtype); - } + std::shared_ptr<typename StructField::upcast<T>::type> typedfield = std::dynamic_pointer_cast<typename StructField::upcast<T>::type>( field_ ); + if( !typedfield ) + CSP_THROW( TypeError, expectedtype << " - provided struct type " << name() << " expected type " << CspType::Type::fromCType<T>::type << " for field " << fieldname + << " but got type " << field_ -> type() -> type() << " for " << expectedtype ); return typedfield; } @@ -827,11 +773,11 @@ class Struct friend class StructMeta; //Note these members are not included on size(), they're stored before "this" ptr ( see operator new / delete ) - struct alignas(8) HiddenData { - alignas(8) size_t refcount; // 8 bytes at 0x0 - alignas(8) std::shared_ptr<const StructMeta> meta; // 16 bytes at 0x8 - alignas(8) void* dialectPtr; // 8 bytes at 0x18 - // Total: 32 bytes + struct HiddenData + { + size_t refcount; + std::shared_ptr<const StructMeta> meta; + void * dialectPtr; }; const HiddenData * hidden() const @@ -839,19 +785,10 @@ class Struct return const_cast<Struct *>( this ) -> hidden(); } - static constexpr size_t HIDDEN_OFFSET = 32; // sizeof(HiddenData) aligned to 8 bytes - - HiddenData* hidden() { - std::byte* base = reinterpret_cast<std::byte*>(this); - // Force alignment to match shared_ptr requirements - static_assert(alignof(HiddenData) >= alignof(std::shared_ptr<void>), - "HiddenData must be aligned for shared_ptr"); - return reinterpret_cast<HiddenData*>(base - HIDDEN_OFFSET); - } - // HiddenData * hidden() - // { - // return reinterpret_cast<HiddenData *>( reinterpret_cast<uint8_t *>( this ) - sizeof( HiddenData ) ); - // } + HiddenData * hidden() + { + return reinterpret_cast<HiddenData *>( reinterpret_cast<uint8_t *>( this ) - sizeof( HiddenData ) ); + } //actual data is allocated past this point }; diff --git a/cpp/csp/python/adapters/websocketadapterimpl.cpp b/cpp/csp/python/adapters/websocketadapterimpl.cpp index db6591271..557ee4454 100644 --- a/cpp/csp/python/adapters/websocketadapterimpl.cpp +++ b/cpp/csp/python/adapters/websocketadapterimpl.cpp @@ -8,7 +8,6 @@ #include <csp/python/PyInputAdapterWrapper.h> #include <csp/python/PyOutputAdapterWrapper.h> -#include <iostream> using namespace csp::adapters::websocket; namespace csp::python @@ -63,9 +62,7 @@ static OutputAdapter * create_websocket_header_update_adapter( csp::AdapterManag static OutputAdapter * create_websocket_connection_request_adapter( csp::AdapterManager * manager, PyEngine * pyengine, PyObject * args ) { - // std::cout << "hereeeee33ee" << "\n"; PyObject * pyProperties; - // PyObject * type; auto * websocketManager = dynamic_cast<ClientAdapterManager*>( manager ); if( !websocketManager ) CSP_THROW( TypeError, "Expected WebsocketClientAdapterManager" ); @@ -73,20 +70,7 @@ static OutputAdapter * create_websocket_connection_request_adapter( csp::Adapter if( !PyArg_ParseTuple( args, "O!", &PyDict_Type, &pyProperties ) ) CSP_THROW( PythonPassthrough, "" ); - // std::cout << "hereeeee334444ee" << "\n"; return websocketManager -> getConnectionRequestAdapter( fromPython<Dictionary>( pyProperties ) ); - - - // TODO - // Here I think we should have a websocket connection manager - // that will handle the connections and endpoint management - // It will create the connection request output adapter - // That output adapter, when it ticks, with a list of python Dictionary - // will then use the boost beast 'post' function to schedule, on the - // io context, a callback to process that dict (on the websocket connection manager!!!) and handle the endpoint manipulation appropriately - - // that websocket connection manager will run the thread with the io context - // being run. Move it away from clientAdapterManager } REGISTER_ADAPTER_MANAGER( _websocket_adapter_manager, create_websocket_adapter_manager ); diff --git a/csp/tests/adapters/test_websocket.py b/csp/tests/adapters/test_websocket.py index 0b37376e6..db2b4ae72 100644 --- a/csp/tests/adapters/test_websocket.py +++ b/csp/tests/adapters/test_websocket.py @@ -58,7 +58,6 @@ def create_tornado_server(port: int = None): def run_io_loop(): nonlocal io_loop, app io_loop = tornado.ioloop.IOLoop() - io_loop.make_current() app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) try: app.listen(port) @@ -301,9 +300,9 @@ def g(): ) ] ) - val = csp.curve(int, [(timedelta(milliseconds=50), 0), (timedelta(milliseconds=500), 1)]) + val = csp.curve(int, [(timedelta(milliseconds=100), 0), (timedelta(milliseconds=500), 1)]) hello = csp.apply(val, lambda x: f"hi world{x}", str) - delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=100)) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=300)) # We connect immediately and send out the hello message ws.send(hello, connection_request=conn_request) @@ -328,12 +327,6 @@ def g(): @pytest.mark.parametrize("reconnect_immeditately", [False, True]) def test_dynamic_disconnect_connect_pruned_subscribe(self, reconnect_immeditately): - @csp.node - def prevent_prune(objs: ts[str]): - if csp.ticked(objs): - # Does nothing but makes sure it's not pruned - ... - @csp.graph def g(): ws = WebsocketAdapterManager(dynamic=True) @@ -354,7 +347,7 @@ def g(): disconnect_reqs, ), ( - timedelta(milliseconds=350), + timedelta(milliseconds=700), [ ConnectionRequest( uri=f"ws://localhost:{self.port}/", @@ -365,7 +358,7 @@ def g(): ], ) const_conn_request = csp.const([ConnectionRequest(uri=f"ws://localhost:{self.port}/")]) - val = csp.curve(int, [(timedelta(milliseconds=100, microseconds=1), 0), (timedelta(milliseconds=500), 1)]) + val = csp.curve(int, [(timedelta(milliseconds=300, microseconds=1), 0), (timedelta(milliseconds=900), 1)]) hello = csp.apply(val, lambda x: f"hi world{x}", str) # We connect immediately and send out the hello message @@ -382,7 +375,7 @@ def g(): recv4 = ws.subscribe( str, RawTextMessageMapper(), - connection_request=csp.const([no_persist_conn], delay=timedelta(milliseconds=250)), + connection_request=csp.const([no_persist_conn], delay=timedelta(milliseconds=500)), ) csp.add_graph_output("recv", recv) @@ -391,7 +384,7 @@ def g(): end = csp.filter(csp.count(recv3) == 3, recv3) csp.stop_engine(end) - msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True) + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=2), realtime=True) # Did not persist, so did not receive any messages assert len(msgs["recv4"]) == 0 # Only the second message is received, since we disonnect before the first one is sent @@ -429,9 +422,9 @@ def g(): ) ] ) - val = csp.curve(int, [(timedelta(milliseconds=50), 0), (timedelta(milliseconds=500), 1)]) + val = csp.curve(int, [(timedelta(milliseconds=100), 0), (timedelta(milliseconds=600), 1)]) hello = csp.apply(val, lambda x: f"hi world{x}", str) - delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=100)) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=400)) # We connect immediately and send out the hello message ws.send(hello, connection_request=conn_request) From 8ca822c487afe6a5f71b980582fa982a49689026 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Tue, 3 Dec 2024 16:16:27 -0500 Subject: [PATCH 6/8] Use rapidjson to parse headers, no more dict Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- .../ClientConnectionRequestAdapter.cpp | 30 ++++++++++++--- .../adapters/websocket/WebsocketEndpoint.cpp | 30 ++++++++++++++- .../adapters/websocket/WebsocketEndpoint.h | 2 +- .../websocket/WebsocketEndpointManager.cpp | 7 +--- csp/adapters/websocket.py | 38 +++++++++---------- csp/adapters/websocket_types.py | 22 +++++++++++ 6 files changed, 94 insertions(+), 35 deletions(-) diff --git a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp index 98b76c4da..2bef1ffb5 100644 --- a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp @@ -1,6 +1,4 @@ #include <csp/adapters/websocket/ClientConnectionRequestAdapter.h> -#include <csp/python/Conversions.h> -#include <Python.h> namespace csp::adapters::websocket { @@ -31,15 +29,35 @@ void ClientConnectionRequestAdapter::executeImpl() if (unlikely(m_isPruned)) return; - auto raw_val = input()->lastValueTyped<PyObject*>(); - auto val = python::fromPython<std::vector<Dictionary>>(raw_val); + std::vector<Dictionary> properties_list; + for (auto& request : input()->lastValueTyped<std::vector<InternalConnectionRequest::Ptr>>()) { + if (!request->allFieldsSet()) + CSP_THROW(TypeError, "All fields must be set in InternalConnectionRequest"); + + Dictionary dict; + dict.update("host", request->host()); + dict.update("port", request->port()); + dict.update("route", request->route()); + dict.update("uri", request->uri()); + dict.update("use_ssl", request->use_ssl()); + dict.update("reconnect_interval", request->reconnect_interval()); + dict.update("persistent", request->persistent()); + + dict.update("headers", request -> headers() ); + dict.update("on_connect_payload", request->on_connect_payload()); + dict.update("action", request->action()); + dict.update("dynamic", request->dynamic()); + dict.update("binary", request->binary()); + + properties_list.push_back(std::move(dict)); + } // We intentionally post here, we want the thread running // the strand to handle the connection request. We want to keep // all updates to internal data structures at graph run-time // to that thread. - boost::asio::post(m_strand, [this, val=std::move(val)]() { - for(const auto& conn_req: val) { + boost::asio::post(m_strand, [this, properties_list=std::move(properties_list)]() { + for(const auto& conn_req: properties_list) { m_websocketManager->handleConnectionRequest(conn_req, m_callerId, m_isSubscribe); } }); diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp index 6bb256488..93f6e8918 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp @@ -1,4 +1,5 @@ #include <csp/adapters/websocket/WebsocketEndpoint.h> +#include <rapidjson/document.h> namespace csp::adapters::websocket { using namespace csp; @@ -6,9 +7,16 @@ using namespace csp; WebsocketEndpoint::WebsocketEndpoint( net::io_context& ioc, Dictionary properties -) : m_properties(std::make_shared<Dictionary>(std::move(properties))), +) : m_properties( std::make_shared<Dictionary>( std::move( properties ) ) ), m_ioc(ioc) -{ }; +{ + std::string headerProps = m_properties->get<std::string>("headers"); + // Create new empty headers dictionary + auto headers = std::make_shared<Dictionary>(); + m_properties->update("headers", headers); + // Update with any existing header properties + updateHeaders(headerProps); +} void WebsocketEndpoint::setOnOpen(void_cb on_open) { m_on_open = std::move(on_open); } void WebsocketEndpoint::setOnFail(string_cb on_fail) @@ -76,6 +84,24 @@ void WebsocketEndpoint::updateHeaders(csp::Dictionary properties){ } } +void WebsocketEndpoint::updateHeaders(const std::string& properties) { + if( properties.empty() ) + return; + DictionaryPtr headers = m_properties->get<DictionaryPtr>("headers"); + rapidjson::Document doc; + doc.Parse(properties.c_str()); + if (doc.IsObject()) { + // Windows builds complained with range loop + for (auto it = doc.MemberBegin(); it != doc.MemberEnd(); ++it) { + if (it->value.IsString()) { + std::string key = it->name.GetString(); + std::string value = it->value.GetString(); + headers->update(key, std::move(value)); + } + } + } +} + std::shared_ptr<Dictionary> WebsocketEndpoint::getProperties() { return m_properties; } diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.h b/cpp/csp/adapters/websocket/WebsocketEndpoint.h index 502d3eb59..aafc237d1 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.h +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.h @@ -379,8 +379,8 @@ class WebsocketEndpoint { void setOnClose(void_cb on_close); void setOnSendFail(string_cb on_send_fail); void updateHeaders(Dictionary properties); + void updateHeaders(const std::string& properties); std::shared_ptr<Dictionary> getProperties(); - // Dictionary& getProperties(); void run(); void stop( bool stop_ioc = true); void send(const std::string& s); diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp index 6df8c8732..6cede55dd 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -241,12 +241,8 @@ void WebsocketEndpointManager::handleConnectionRequest(const Dictionary & proper // This should only get called from the thread running // m_ioc. This allows us to avoid locks on internal data // structures - // std::cout << " HEY YA\n"; auto endpoint_id = properties.get<std::string>("uri"); - // std::cout << endpoint_id; autogen::ActionType action = autogen::ActionType::create( properties.get<std::string>("action") ); - // std::cout << action.asString() << "\n"; - // // Change headers if needed here! switch(action.enum_value()) { case autogen::ActionType::enum_::CONNECT: { auto persistent = properties.get<bool>("persistent"); @@ -289,10 +285,9 @@ void WebsocketEndpointManager::handleConnectionRequest(const Dictionary & proper m_endpoints[endpoint_id]->send(payload); // Conscious decision to let non-persisten connection // results to update the header - auto headers = *properties.get<DictionaryPtr>("headers"); + auto headers = properties.get<std::string>("headers"); m_endpoints[endpoint_id]->updateHeaders(std::move(headers)); } - // } break; } diff --git a/csp/adapters/websocket.py b/csp/adapters/websocket.py index d1a5c1680..1a88f48c2 100644 --- a/csp/adapters/websocket.py +++ b/csp/adapters/websocket.py @@ -18,7 +18,13 @@ RawBytesMessageMapper, RawTextMessageMapper, ) -from csp.adapters.websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus +from csp.adapters.websocket_types import ( + ActionType, + ConnectionRequest, + InternalConnectionRequest, + WebsocketHeaderUpdate, + WebsocketStatus, +) from csp.impl.wiring import input_adapter_def, output_adapter_def, status_adapter_def from csp.impl.wiring.delayed_node import DelayedNodeWrapperDef from csp.lib import _websocketadapterimpl @@ -396,18 +402,6 @@ def _instantiate(self): _launch_application(self._port, manager, csp.const("stub")) -# Maybe, we can have the Adapter manager have all the connections -# If we get a new connection request, we include that adapter for the -# subscriptions. When we pop it, we remove it. -# Then, each edge will effectively be independent. -# Maybe. have each websocket push to a shared queue, then from there we -# pass it along to all edges ("input adapters") that are subscribed to it - -# Ok, maybe, let's keep it at just 1 subscribe and send call. -# However, we can subscribe to the send and subscribe calls separately. -# We just have to keep track of the Endpoints we have, and - - class WebsocketAdapterManager: """ Can subscribe dynamically via ts[List[ConnectionRequest]] @@ -453,7 +447,7 @@ def __init__( connection_request = ConnectionRequest( uri=uri, reconnect_interval=reconnect_interval, headers=headers or {} ) - self._properties.update(self._get_properties(connection_request)) + self._properties.update(self._get_properties(connection_request).to_dict()) # This is a counter that will be used to identify every function call # We keep track of the subscribes and sends separately @@ -467,7 +461,7 @@ def __init__( def _dynamic(self): return self._properties.get("dynamic", False) - def _get_properties(self, conn_request: ConnectionRequest) -> dict: + def _get_properties(self, conn_request: ConnectionRequest) -> InternalConnectionRequest: uri = conn_request.uri reconnect_interval = conn_request.reconnect_interval @@ -476,14 +470,14 @@ def _get_properties(self, conn_request: ConnectionRequest) -> dict: if resp.hostname is None: raise ValueError(f"Failed to parse host from URI: {uri}") - res = dict( + res = InternalConnectionRequest( host=resp.hostname, # if no port is explicitly present in the uri, the resp.port is None port=_sanitize_port(uri, resp.port), route=resp.path or "/", # resource shouldn't be empty string use_ssl=uri.startswith("wss"), reconnect_interval=reconnect_interval, - headers=conn_request.headers, + headers=rapidjson.dumps(conn_request.headers) if conn_request.headers else "", persistent=conn_request.persistent, action=conn_request.action.name, on_connect_payload=conn_request.on_connect_payload, @@ -537,7 +531,9 @@ def subscribe( adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=True).model_dump() connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request request_dict = csp.apply( - connection_request, lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], list + connection_request, + lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], + List[InternalConnectionRequest], ) # Output adapter to handle connection requests _websocket_connection_request_adapter_def(self, request_dict, adapter_props) @@ -566,7 +562,9 @@ def send(self, x: ts["T"], connection_request: Optional[ts[List[ConnectionReques adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=False).model_dump() connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request request_dict = csp.apply( - connection_request, lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], list + connection_request, + lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], + List[InternalConnectionRequest], ) _websocket_connection_request_adapter_def(self, request_dict, adapter_props) return _websocket_output_adapter_def(self, x, adapter_props) @@ -614,6 +612,6 @@ def _create(self, engine, memo): "websocket_connection_request_adapter", _websocketadapterimpl._websocket_connection_request_adapter, WebsocketAdapterManager, - input=ts[list], # needed, List[dict] didn't work on c++ level + input=ts[List[InternalConnectionRequest]], # needed, List[dict] didn't work on c++ level properties=dict, ) diff --git a/csp/adapters/websocket_types.py b/csp/adapters/websocket_types.py index 314d2d8a9..27eb0ecbd 100644 --- a/csp/adapters/websocket_types.py +++ b/csp/adapters/websocket_types.py @@ -34,3 +34,25 @@ class ConnectionRequest(Struct): reconnect_interval: timedelta = timedelta(seconds=2) on_connect_payload: str = "" # message to send on connect headers: Dict[str, str] = {} + + +# Only used internally +class InternalConnectionRequest(Struct): + host: str # Hostname parsed from the URI + port: str # Port number for the connection (parsed and sanitized from URI) + route: str # Resource path from URI, defaults to "/" if empty + uri: str # Complete original URI string + + # Connection behavior + use_ssl: bool # Whether to use secure WebSocket (wss://) + reconnect_interval: timedelta # Time to wait between reconnection attempts + persistent: bool # Whether to maintain a persistent connection + + # Headers and payloads + headers: str # HTTP headers for the connection as json string + on_connect_payload: str # Message to send when connection is established + + # Connection metadata + action: str # Connection action type (Connect, Disconnect, Ping, etc) + dynamic: bool # Whether the connection is dynamic + binary: bool # Whether to use binary mode for the connection From d6f288c0f9585b8d58fa1b7331a7de3b2fbc5857 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Tue, 10 Dec 2024 16:06:40 -0500 Subject: [PATCH 7/8] Dont use tuple to processMessage Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- .../adapters/websocket/ClientInputAdapter.cpp | 19 +-------------- .../adapters/websocket/ClientInputAdapter.h | 3 +-- .../websocket/WebsocketEndpointManager.cpp | 2 +- cpp/csp/python/Conversions.h | 24 ------------------- csp/adapters/websocket_types.py | 2 +- 5 files changed, 4 insertions(+), 46 deletions(-) diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp index cff97f11b..103b63aff 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp @@ -32,25 +32,8 @@ ClientInputAdapter::ClientInputAdapter( m_converter = adapters::utils::MessageStructConverterCache::instance().create( type, properties ); }; -void ClientInputAdapter::processMessage( void* c, size_t t, PushBatch* batch ) +void ClientInputAdapter::processMessage( const std::string& source, void * c, size_t t, PushBatch* batch ) { - - if( dataType() -> type() == CspType::Type::STRUCT ) - { - auto tick = m_converter -> asStruct( c, t ); - pushTick( std::move(tick), batch ); - } else if ( dataType() -> type() == CspType::Type::STRING ) - { - pushTick( std::string((char const*)c, t), batch ); - } - -} - -void ClientInputAdapter::processMessage( std::tuple<std::string, void*> data, size_t t, PushBatch* batch ) -{ - // Extract the source string and data pointer from tuple - std::string source = std::get<0>(data); - void* c = std::get<1>(data); if ( m_dynamic ){ auto& actual_type = static_cast<const CspStructType &>( *dataType() ); auto& nested_type = actual_type.meta()-> field( "msg" ) -> type(); diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.h b/cpp/csp/adapters/websocket/ClientInputAdapter.h index a5ceda58e..0f5a4223b 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.h +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.h @@ -20,8 +20,7 @@ class ClientInputAdapter final: public PushInputAdapter { bool dynamic ); - void processMessage( void* c, size_t t, PushBatch* batch ); - void processMessage( std::tuple<std::string, void*> data, size_t t, PushBatch* batch ); + void processMessage( const std::string& source, void * c, size_t t, PushBatch* batch ); private: adapters::utils::MessageStructConverterPtr m_converter; diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp index 6cede55dd..32f78f3ca 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -176,7 +176,7 @@ void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, std::vector<uint8_t> data_copy(static_cast<uint8_t*>(data), static_cast<uint8_t*>(data) + len); auto tup = std::tuple<std::string, void*>{endpoint_id, data_copy.data()}; - m_inputAdapters[consumer_id] -> processMessage( std::move(tup), len, &batch ); + m_inputAdapters[consumer_id] -> processMessage( endpoint_id, data_copy.data(), len, &batch ); } } }); diff --git a/cpp/csp/python/Conversions.h b/cpp/csp/python/Conversions.h index a3ada6b2d..2422aaea1 100644 --- a/cpp/csp/python/Conversions.h +++ b/cpp/csp/python/Conversions.h @@ -666,30 +666,6 @@ inline Dictionary fromPython( PyObject * o ) return out; } -template<> -inline std::vector<Dictionary> fromPython(PyObject* o) -{ - if (!PyList_Check(o)) - CSP_THROW(TypeError, "List of dictionaries conversion expected type list got " << Py_TYPE(o)->tp_name); - - Py_ssize_t size = PyList_GET_SIZE(o); - std::vector<Dictionary> out; - out.reserve(size); - - for (Py_ssize_t i = 0; i < size; ++i) - { - PyObject* item = PyList_GET_ITEM(o, i); - - // Skip None values like in Dictionary conversion - if (item == Py_None) - continue; - - out.emplace_back(fromPython<Dictionary>(item)); - } - - return out; -} - template<> inline std::vector<Dictionary::Data> fromPython( PyObject * o ) { diff --git a/csp/adapters/websocket_types.py b/csp/adapters/websocket_types.py index 27eb0ecbd..542a1c205 100644 --- a/csp/adapters/websocket_types.py +++ b/csp/adapters/websocket_types.py @@ -29,7 +29,7 @@ class WebsocketHeaderUpdate(Struct): class ConnectionRequest(Struct): uri: str action: ActionType = ActionType.CONNECT # Connect, Disconnect, Ping, etc - # Whetehr we maintain the connection + # Whether we maintain the connection persistent: bool = True # Only relevant for Connect requests reconnect_interval: timedelta = timedelta(seconds=2) on_connect_payload: str = "" # message to send on connect From 584de8a109e7c29959cd154e4d823d6a16671208 Mon Sep 17 00:00:00 2001 From: Nijat K <nijat.khanbabayev@gmail.com> Date: Thu, 9 Jan 2025 21:00:28 -0500 Subject: [PATCH 8/8] Fix wrong function call for ioc from boost beast Signed-off-by: Nijat K <nijat.khanbabayev@gmail.com> --- cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp index 32f78f3ca..d124f04f9 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -45,7 +45,7 @@ WebsocketEndpointManager::~WebsocketEndpointManager() } void WebsocketEndpointManager::start(DateTime starttime, DateTime endtime) { - m_ioc.reset(); + m_ioc.restart(); if( !m_dynamic ){ boost::asio::post(m_strand, [this]() { // We subscribe for both the subscribe and send calls