diff --git a/.gitignore b/.gitignore index 20e1433f..f9d5346a 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,7 @@ io_test parse_test crypto_test status_code_test + +.idea +/debug +/release diff --git a/server_http.hpp b/server_http.hpp index 60cd760e..734fa434 100644 --- a/server_http.hpp +++ b/server_http.hpp @@ -44,26 +44,27 @@ namespace SimpleWeb { #endif namespace SimpleWeb { - template + template class Server; - template class ServerBase { protected: + class Connection; class Session; public: class Response : public std::enable_shared_from_this, public std::ostream { - friend class ServerBase; - friend class Server; + friend class ServerBase; + protected: asio::streambuf streambuf; std::shared_ptr session; long timeout_content; - Response(std::shared_ptr session, long timeout_content) noexcept : std::ostream(&streambuf), session(std::move(session)), timeout_content(timeout_content) {} + Response(std::shared_ptr session, long timeout_content) noexcept : std::ostream(&streambuf), session(std::move(session)), timeout_content(timeout_content) { } + private: template void write_header(const CaseInsensitiveMultimap &header, size_type size) { bool content_length_written = false; @@ -87,18 +88,12 @@ namespace SimpleWeb { return streambuf.size(); } + /// Use this function if you need to recursively send parts of a longer message void send(const std::function &callback = nullptr) noexcept { session->connection->set_timeout(timeout_content); auto self = this->shared_from_this(); // Keep Response instance alive through the following async_write - asio::async_write(*session->connection->socket, streambuf, [self, callback](const error_code &ec, std::size_t /*bytes_transferred*/) { - self->session->connection->cancel_timeout(); - auto lock = self->session->connection->handler_runner->continue_lock(); - if(!lock) - return; - if(callback) - callback(ec); - }); + async_write(callback); } /// Write directly to stream buffer using std::ostream::write @@ -151,10 +146,13 @@ namespace SimpleWeb { /// This is useful when implementing a HTTP/1.0-server sending content /// without specifying the content length. bool close_connection_after_response = false; + + private: + virtual void async_write(const std::function &callback) = 0; }; class Content : public std::istream { - friend class ServerBase; + friend class ServerBase; public: std::size_t size() noexcept { @@ -178,8 +176,7 @@ namespace SimpleWeb { }; class Request { - friend class ServerBase; - friend class Server; + friend class ServerBase; friend class Session; asio::streambuf streambuf; @@ -223,23 +220,26 @@ namespace SimpleWeb { protected: class Connection : public std::enable_shared_from_this { public: - template - Connection(std::shared_ptr handler_runner, Args &&... args) noexcept : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward(args)...)) {} + Connection(std::shared_ptr handler_runner) noexcept : handler_runner(std::move(handler_runner)) {} + virtual ~Connection() {}; std::shared_ptr handler_runner; - std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable + // std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable std::mutex socket_close_mutex; std::unique_ptr timer; std::shared_ptr remote_endpoint; + virtual asio::ip::tcp::socket::lowest_layer_type& lowest_layer() = 0; + virtual asio::io_service& get_io_service() = 0; + void close() noexcept { error_code ec; std::unique_lock lock(socket_close_mutex); // The following operations seems to be needed to run sequentially - socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - socket->lowest_layer().close(ec); + lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + lowest_layer().close(ec); } void set_timeout(long seconds) noexcept { @@ -248,7 +248,7 @@ namespace SimpleWeb { return; } - timer = std::unique_ptr(new asio::steady_timer(socket->get_io_service())); + timer = std::unique_ptr(new asio::steady_timer(get_io_service())); timer->expires_from_now(std::chrono::seconds(seconds)); auto self = this->shared_from_this(); timer->async_wait([self](const error_code &ec) { @@ -266,11 +266,12 @@ namespace SimpleWeb { }; class Session { + friend class ServerBase; public: Session(std::size_t max_request_streambuf_size, std::shared_ptr connection) noexcept : connection(std::move(connection)) { if(!this->connection->remote_endpoint) { error_code ec; - this->connection->remote_endpoint = std::make_shared(this->connection->socket->lowest_layer().remote_endpoint(ec)); + this->connection->remote_endpoint = std::make_shared(this->connection->lowest_layer().remote_endpoint(ec)); } request = std::shared_ptr(new Request(max_request_streambuf_size, this->connection->remote_endpoint)); } @@ -281,7 +282,7 @@ namespace SimpleWeb { public: class Config { - friend class ServerBase; + friend class ServerBase; Config(unsigned short port) noexcept : port(port) {} @@ -321,13 +322,12 @@ namespace SimpleWeb { public: /// Warning: do not add or remove resources after start() is called - std::map::Response>, std::shared_ptr::Request>)>>> resource; + std::map, std::shared_ptr)>>> resource; - std::map::Response>, std::shared_ptr::Request>)>> default_resource; + std::map, std::shared_ptr)>> default_resource; - std::function::Request>, const error_code &)> on_error; + std::function, const error_code &)> on_error; - std::function &, std::shared_ptr::Request>)> on_upgrade; /// If you have your own asio::io_service, store its pointer here before running start(). std::shared_ptr io_service; @@ -411,38 +411,28 @@ namespace SimpleWeb { ServerBase(unsigned short port) noexcept : config(port), connections(new std::unordered_set()), connections_mutex(new std::mutex()), handler_runner(new ScopeRunner()) {} + virtual Response* create_response(std::shared_ptr session, long timeout_content) = 0; virtual void accept() = 0; - template - std::shared_ptr create_connection(Args &&... args) noexcept { - auto connections = this->connections; - auto connections_mutex = this->connections_mutex; - auto connection = std::shared_ptr(new Connection(handler_runner, std::forward(args)...), [connections, connections_mutex](Connection *connection) { - { - std::unique_lock lock(*connections_mutex); - auto it = connections->find(connection); - if(it != connections->end()) - connections->erase(it); - } - delete connection; - }); - { - std::unique_lock lock(*connections_mutex); - connections->emplace(connection.get()); - } - return connection; - } + virtual void async_read_until(std::shared_ptr connection, asio::streambuf& streambuf, const std::string& eol, + std::function lambda) = 0; + virtual void async_read(std::shared_ptr connection, asio::streambuf& streambuf, std::size_t length, + std::function lambda) = 0; + + virtual void do_upgrade (const std::shared_ptr &session) = 0; - void read(const std::shared_ptr &session) { + void read(const std::shared_ptr &session) { session->connection->set_timeout(config.timeout_request); - asio::async_read_until(*session->connection->socket, session->request->streambuf, "\r\n\r\n", [this, session](const error_code &ec, std::size_t bytes_transferred) { + std::function lambda = + [this, session](const error_code &ec, size_t bytes_transferred) { + session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; session->request->header_read_time = std::chrono::system_clock::now(); if((!ec || ec == asio::error::not_found) && session->request->streambuf.size() == session->request->streambuf.max_size()) { - auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + auto response = std::shared_ptr(create_response(session, this->config.timeout_content)); response->write(StatusCode::client_error_payload_too_large); response->send(); if(this->on_error) @@ -477,14 +467,17 @@ namespace SimpleWeb { } if(content_length > num_additional_bytes) { session->connection->set_timeout(config.timeout_content); - asio::async_read(*session->connection->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, std::size_t /*bytes_transferred*/) { + + std::function lambda2 = + [this, session](const error_code &ec, size_t /* bytes_transferred */) { + session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { if(session->request->streambuf.size() == session->request->streambuf.max_size()) { - auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + auto response = std::shared_ptr(create_response(session, this->config.timeout_content)); response->write(StatusCode::client_error_payload_too_large); response->send(); if(this->on_error) @@ -495,8 +488,11 @@ namespace SimpleWeb { } else if(this->on_error) this->on_error(session->request, ec); - }); - } + }; + + async_read(session->connection, session->request->streambuf, content_length - num_additional_bytes, lambda2); + + } else this->find_resource(session); } @@ -509,18 +505,24 @@ namespace SimpleWeb { } else if(this->on_error) this->on_error(session->request, ec); - }); - } + }; - void read_chunked_transfer_encoded(const std::shared_ptr &session, const std::shared_ptr &chunks_streambuf) { + async_read_until(session->connection, session->request->streambuf, "\r\n\r\n", lambda); + } + + void read_chunked_transfer_encoded(const std::shared_ptr &session, + const std::shared_ptr &chunks_streambuf) { session->connection->set_timeout(config.timeout_content); - asio::async_read_until(*session->connection->socket, session->request->streambuf, "\r\n", [this, session, chunks_streambuf](const error_code &ec, size_t bytes_transferred) { + + std::function lambda = + [this, session, chunks_streambuf](const error_code &ec, size_t bytes_transferred) { + session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; if((!ec || ec == asio::error::not_found) && session->request->streambuf.size() == session->request->streambuf.max_size()) { - auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + auto response = std::shared_ptr(create_response(session, this->config.timeout_content)); response->write(StatusCode::client_error_payload_too_large); response->send(); if(this->on_error) @@ -546,14 +548,17 @@ namespace SimpleWeb { if((2 + length) > num_additional_bytes) { session->connection->set_timeout(config.timeout_content); - asio::async_read(*session->connection->socket, session->request->streambuf, asio::transfer_exactly(2 + length - num_additional_bytes), [this, session, chunks_streambuf, length](const error_code &ec, size_t /*bytes_transferred*/) { + + std::function lambda2 = + [this, session, chunks_streambuf, length](const error_code &ec, size_t /* bytes_transferred */) { + session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { if(session->request->streambuf.size() == session->request->streambuf.max_size()) { - auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + auto response = std::shared_ptr(create_response(session, this->config.timeout_content)); response->write(StatusCode::client_error_payload_too_large); response->send(); if(this->on_error) @@ -564,14 +569,19 @@ namespace SimpleWeb { } else if(this->on_error) this->on_error(session->request, ec); - }); + }; + + async_read(session->connection, session->request->streambuf, 2 + length - num_additional_bytes, lambda2); + } else this->read_chunked_transfer_encoded_chunk(session, chunks_streambuf, length); } else if(this->on_error) this->on_error(session->request, ec); - }); + }; + + async_read_until(session->connection, session->request->streambuf, "\r\n", lambda); } void read_chunked_transfer_encoded_chunk(const std::shared_ptr &session, const std::shared_ptr &chunks_streambuf, unsigned long length) { @@ -581,7 +591,7 @@ namespace SimpleWeb { session->request->content.read(buffer.get(), static_cast(length)); tmp_stream.write(buffer.get(), static_cast(length)); if(chunks_streambuf->size() == chunks_streambuf->max_size()) { - auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + auto response = std::shared_ptr(create_response(session, this->config.timeout_content)); response->write(StatusCode::client_error_payload_too_large); response->send(); if(this->on_error) @@ -607,21 +617,8 @@ namespace SimpleWeb { void find_resource(const std::shared_ptr &session) { // Upgrade connection - if(on_upgrade) { - auto it = session->request->header.find("Upgrade"); - if(it != session->request->header.end()) { - // remove connection from connections - { - std::unique_lock lock(*connections_mutex); - auto it = connections->find(session->connection.get()); - if(it != connections->end()) - connections->erase(it); - } + do_upgrade(session); - on_upgrade(session->connection->socket, session->request); - return; - } - } // Find path- and method-match, and call write for(auto ®ex_method : resource) { auto it = regex_method.second.find(session->request->method); @@ -640,9 +637,9 @@ namespace SimpleWeb { } void write(const std::shared_ptr &session, - std::function::Response>, std::shared_ptr::Request>)> &resource_function) { + std::function, std::shared_ptr)> &resource_function) { session->connection->set_timeout(config.timeout_content); - auto response = std::shared_ptr(new Response(session, config.timeout_content), [this](Response *response_ptr) { + auto response = std::shared_ptr(create_response(session, config.timeout_content), [this](Response *response_ptr) { auto response = std::shared_ptr(response_ptr); response->send([this, response](const error_code &ec) { if(!ec) { @@ -681,17 +678,80 @@ namespace SimpleWeb { } }; - template - class Server : public ServerBase {}; + template + class Server : public ServerBase {}; using HTTP = asio::ip::tcp::socket; template <> - class Server : public ServerBase { + class Server : public ServerBase { public: - Server() noexcept : ServerBase::ServerBase(80) {} + Server() noexcept : ServerBase::ServerBase(80) {} + + std::function &, std::shared_ptr)> on_upgrade; + + private: + template + class HttpConnection: public Connection { + friend class Server; + friend class ServerBase; + + template + HttpConnection(std::shared_ptr handler_runner, Args &&... args) noexcept : Connection(handler_runner), socket(new HTTP(std::forward(args)...)) {} + + ~HttpConnection() {}; + + std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable + + asio::ip::tcp::socket::lowest_layer_type& lowest_layer() override { return socket->lowest_layer(); } + asio::io_service& get_io_service() override { return socket->get_io_service(); } + }; + + class HttpResponse: public Response { + friend class Server; + + HttpResponse(std::shared_ptr session, long timeout_content) noexcept : Response(session, timeout_content) {} + + void async_write(const std::function &callback) override { + auto session = this->session; + + asio::async_write(*static_cast*>(session->connection.get())->socket + , streambuf, + [session, callback](const error_code& ec, std::size_t /*bytes_transferred*/) { + session->connection->cancel_timeout(); + auto lock = session->connection->handler_runner->continue_lock(); + if (!lock) + return; + if (callback) + callback(ec); + }); + } + }; + + Response* create_response(std::shared_ptr session, long timeout_content) override { + return new HttpResponse(session, timeout_content); + } + + template + std::shared_ptr> create_connection(Args &&... args) noexcept { + auto connections = this->connections; + auto connections_mutex = this->connections_mutex; + auto connection = std::shared_ptr>(new HttpConnection(handler_runner, std::forward(args)...), [connections, connections_mutex](Connection *connection) { + { + std::unique_lock lock(*connections_mutex); + auto it = connections->find(connection); + if(it != connections->end()) + connections->erase(it); + } + delete connection; + }); + { + std::unique_lock lock(*connections_mutex); + connections->emplace(connection.get()); + } + return connection; + } - protected: void accept() override { auto connection = create_connection(*io_service); @@ -709,7 +769,7 @@ namespace SimpleWeb { if(!ec) { asio::ip::tcp::no_delay option(true); error_code ec; - session->connection->socket->set_option(option, ec); + connection->socket->set_option(option, ec); this->read(session); } @@ -717,6 +777,38 @@ namespace SimpleWeb { this->on_error(session->request, ec); }); } + + void async_read(std::shared_ptr connection, asio::streambuf& streambuf, std::size_t length, + std::function lambda) { + + asio::async_read(*static_cast*>(connection.get())->socket, + streambuf, asio::transfer_exactly(length), lambda); + } + + void async_read_until(std::shared_ptr connection, asio::streambuf& streambuf, const std::string& eol, + std::function lambda) override { + asio::async_read_until(*static_cast*>(connection.get())->socket, + streambuf, eol, lambda); + } + + void do_upgrade (const std::shared_ptr& session) { + if(on_upgrade) { + auto it = session->request->header.find("Upgrade"); + if(it != session->request->header.end()) { + // remove connection from connections + { + std::unique_lock lock(*connections_mutex); + auto it = connections->find(session->connection.get()); + if(it != connections->end()) + connections->erase(it); + } + + on_upgrade(static_cast*>(session->connection.get())->socket, session->request); + return; + } + } + + } }; } // namespace SimpleWeb diff --git a/server_https.hpp b/server_https.hpp index 4840b82c..9c2c4360 100644 --- a/server_https.hpp +++ b/server_https.hpp @@ -16,13 +16,13 @@ namespace SimpleWeb { using HTTPS = asio::ssl::stream; template <> - class Server : public ServerBase { + class Server : public ServerBase { std::string session_id_context; bool set_session_id_context = false; public: Server(const std::string &cert_file, const std::string &private_key_file, const std::string &verify_file = std::string()) - : ServerBase::ServerBase(443), context(asio::ssl::context::tlsv12) { + : ServerBase::ServerBase(443), context(asio::ssl::context::tlsv12) { context.use_certificate_chain_file(cert_file); context.use_private_key_file(private_key_file, asio::ssl::context::pem); @@ -47,10 +47,74 @@ namespace SimpleWeb { protected: asio::ssl::context context; + + private: + template + class HttpsConnection: public Connection { + friend class Server; + friend class ServerBase; + + template + HttpsConnection(std::shared_ptr handler_runner, Args &&... args) noexcept : Connection(handler_runner), socket(new HTTPS(std::forward(args)...)) {} + + ~HttpsConnection() {}; + + std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable + + asio::ip::tcp::socket::lowest_layer_type& lowest_layer() override { return socket->lowest_layer(); } + asio::io_service& get_io_service() override { return socket->get_io_service(); } + }; + + class HttpsResponse: public Response { + friend class Server; + + HttpsResponse(std::shared_ptr session, long timeout_content) noexcept : Response(session, timeout_content) {} + + void async_write(const std::function &callback) override { + auto session = this->session; + + asio::async_write(*static_cast*>(session->connection.get())->socket + , streambuf, + [session, callback](const error_code& ec, std::size_t /*bytes_transferred*/) { + session->connection->cancel_timeout(); + auto lock = session->connection->handler_runner->continue_lock(); + if (!lock) + return; + if (callback) + callback(ec); + }); + } + }; + + Response* create_response(std::shared_ptr session, long timeout_content) override { + return new HttpsResponse(session, timeout_content); + } + + template + std::shared_ptr> create_connection(Args &&... args) noexcept { + auto connections = this->connections; + auto connections_mutex = this->connections_mutex; + auto connection = std::shared_ptr>(new HttpsConnection(handler_runner, std::forward(args)...), [connections, connections_mutex](Connection *connection) { + { + std::unique_lock lock(*connections_mutex); + auto it = connections->find(connection); + if(it != connections->end()) + connections->erase(it); + } + delete connection; + }); + { + std::unique_lock lock(*connections_mutex); + connections->emplace(connection.get()); + } + return connection; + } + + void accept() override { auto connection = create_connection(*io_service, context); - acceptor->async_accept(connection->socket->lowest_layer(), [this, connection](const error_code &ec) { + acceptor->async_accept(connection->lowest_layer(), [this, connection](const error_code &ec) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; @@ -63,10 +127,10 @@ namespace SimpleWeb { if(!ec) { asio::ip::tcp::no_delay option(true); error_code ec; - session->connection->socket->lowest_layer().set_option(option, ec); + session->connection->lowest_layer().set_option(option, ec); session->connection->set_timeout(config.timeout_request); - session->connection->socket->async_handshake(asio::ssl::stream_base::server, [this, session](const error_code &ec) { + static_cast*>(connection.get())->socket->async_handshake(asio::ssl::stream_base::server, [this, session](const error_code &ec) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) @@ -81,6 +145,40 @@ namespace SimpleWeb { this->on_error(session->request, ec); }); } + + std::function &, std::shared_ptr)> on_upgrade; + + void async_read(std::shared_ptr connection, asio::streambuf& streambuf, std::size_t length, + std::function lambda) { + + asio::async_read(*static_cast*>(connection.get())->socket, + streambuf, asio::transfer_exactly(length), lambda); + } + + void async_read_until(std::shared_ptr connection, asio::streambuf& streambuf, const std::string& eol, + std::function lambda) override { + asio::async_read_until(*static_cast*>(connection.get())->socket, + streambuf, eol, lambda); + } + + void do_upgrade (const std::shared_ptr& session) { + if(on_upgrade) { + auto it = session->request->header.find("Upgrade"); + if(it != session->request->header.end()) { + // remove connection from connections + { + std::unique_lock lock(*connections_mutex); + auto it = connections->find(session->connection.get()); + if(it != connections->end()) + connections->erase(it); + } + + on_upgrade(static_cast*>(session->connection.get())->socket, session->request); + return; + } + } + + } }; } // namespace SimpleWeb