Skip to content

Commit

Permalink
feat(SSL): move SSLManager singleton and SSLContext into the Server a…
Browse files Browse the repository at this point in the history
…nd add logic to acceptNewConnection to create a ssl struct for the connection and pass it to the Connection object on initialisation
  • Loading branch information
552020 committed May 26, 2024
1 parent 9762a58 commit 58bd4cb
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 14 deletions.
27 changes: 26 additions & 1 deletion src/Connection.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "Connection.hpp"

Connection::Connection(struct pollfd &pollFd, Server &server)
Connection::Connection(struct pollfd &pollFd, Server &server, SSL *ssl)

{
(void)server;
Expand Down Expand Up @@ -28,6 +28,8 @@ Connection::Connection(struct pollfd &pollFd, Server &server)
_CGIHasTimedOut = false;
_CGIHasReadPipe = false;
_startTime = 0;
_ssl = ssl;
_isSSL = ssl != NULL;
}

Connection::Connection(const Connection &other)
Expand Down Expand Up @@ -61,6 +63,8 @@ Connection::Connection(const Connection &other)
_CGIHasReadPipe = other._CGIHasReadPipe;
_cgiOutputBuffer = other._cgiOutputBuffer;
_startTime = other._startTime;
_ssl = other._ssl;
_isSSL = other._isSSL;
// std::cout << "Connection object copied" << std::endl;
}

Expand Down Expand Up @@ -95,6 +99,8 @@ Connection &Connection::operator=(const Connection &other)
_CGIHasReadPipe = other._CGIHasReadPipe;
_cgiOutputBuffer = other._cgiOutputBuffer;
_startTime = other._startTime;
_ssl = other._ssl;
_isSSL = other._isSSL;
}
Debug::log("Connection object assigned", Debug::OCF);
return *this;
Expand Down Expand Up @@ -242,6 +248,16 @@ time_t Connection::getStartTime() const
return _startTime;
}

bool Connection::getIsSSL() const
{
return _isSSL;
}

SSL *Connection::getSSL() const
{
return _ssl;
}

// SETTERS

void Connection::setStartTime(time_t time)
Expand Down Expand Up @@ -348,6 +364,15 @@ void Connection::setCGIExitStatus(int status)
_CGIExitStatus = status;
}

void Connection::setIsSSL(bool value)
{
_isSSL = value;
}

void Connection::setSSL(SSL *ssl)
{
_ssl = ssl;
}
// METHODS

ssize_t Connection::readSocket(char *buffer, size_t bufferSize)
Expand Down
10 changes: 9 additions & 1 deletion src/Connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ class Connection
bool _CGIHasReadPipe;
std::string _cgiOutputBuffer;

/* SSL */
SSL *_ssl;
bool _isSSL;

public:
Connection(struct pollfd &pollFd, Server &server);
Connection(struct pollfd &pollFd, Server &server, SSL *ssl = NULL);
Connection(const Connection &other);
Connection &operator=(const Connection &other);
~Connection();
Expand Down Expand Up @@ -105,6 +109,8 @@ class Connection
bool getCGIHasTimedOut() const;
bool getCGIHasReadPipe() const;
std::string getCGIOutputBuffer() const;
SSL *getSSL() const;
bool getIsSSL() const;

/* Setters */
void setResponseString(std::string responseString);
Expand All @@ -131,6 +137,8 @@ class Connection
void setCGIHasTimedOut(bool value);
void setCGIHasReadPipe(bool value);
void setCGIOutputBuffer(std::string output);
void setSSL(SSL *ssl);
void setIsSSL(bool isSSL);
/* CGI */
void addCGI(pid_t pid);
void removeCGI(int status);
Expand Down
52 changes: 44 additions & 8 deletions src/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "EventManager.hpp"
#include "signal.h"

#define SSL_PORT 8443

Server::Server(const Config &config, EventManager &eventManager) : _config(config), _eventManager(eventManager)
{
_maxClients = 10;
Expand All @@ -15,6 +17,8 @@ Server::Server(const Config &config, EventManager &eventManager) : _config(confi
_hasCGI = false;
_CGICounter = 0;
_clientCounter = 0;
_sslManager = SSLManager::getInstance();
_sslContext = SSLContext();
Debug::log("Server created with config constructor", Debug::OCF);
}

Expand Down Expand Up @@ -481,7 +485,15 @@ void Server::closeClientConnection(Connection &conn, size_t &i)
<< "Entering closeClientConnection"
<< "\033[0m" << std::endl;
// TODO: should we close it with the Destructor of the Connection class?
if (conn.getIsSSL() && conn.getSSL() != NULL)
{
SSL_shutdown(conn.getSSL());
SSL_free(conn.getSSL());
conn.setSSL(NULL);
conn.setIsSSL(false);
}
close(conn.getPollFd().fd);

_FDs.erase(_FDs.begin() + i);
_connections.erase(_connections.begin() + i);
_connectionsPerIP[conn.getServerIp()] -= 1;
Expand Down Expand Up @@ -719,28 +731,52 @@ void Server::addServerSocketsPollFdToVectors()
}
}

void Server::acceptNewConnection(Connection &conn)
void Server::acceptNewConnection(Connection &serverConn)
{

Debug::log("SERVER SOCKET EVENT", Debug::SERVER, CYAN, false, true);
struct sockaddr_storage clientAddress;
socklen_t ClientAddrLen = sizeof(clientAddress);
int newSocket = accept(conn.getPollFd().fd, (struct sockaddr *)&clientAddress, (socklen_t *)&ClientAddrLen);
int newSocket = accept(serverConn.getPollFd().fd, (struct sockaddr *)&clientAddress, (socklen_t *)&ClientAddrLen);
if (newSocket >= 0)
{
struct pollfd newSocketPoll;
newSocketPoll.fd = newSocket;
newSocketPoll.events = POLLIN;
newSocketPoll.revents = 0;
Connection newConnection(newSocketPoll, *this);
// Before we create a new connection object, we set up the connection as SSL or not
SSL *ssl = NULL;
if (serverConn.getServerPort() == SSL_PORT)
{
// We create a new SSL object
ssl = SSL_new(_sslContext.getContext());
if (ssl == NULL)
{
Debug::log("Error creating SSL object", Debug::SERVER);
perror("In SSL_new");
close(newSocket);
return;
}
SSL_set_fd(ssl, newSocket);
if (SSL_accept(ssl) <= 0)
{
Debug::log("Error accepting SSL connection", Debug::SERVER);
perror("In SSL_accept");
SSL_free(ssl);
close(newSocket);
return;
}
// We set the SSL object to the connection
}
Connection newConnection(newSocketPoll, *this, ssl);
newConnection.setType(CLIENT);
newConnection.setServerIp(conn.getServerIp());
if (_connectionsPerIP.find(conn.getServerIp()) == _connectionsPerIP.end())
_connectionsPerIP.insert(std::pair<std::string, int>(conn.getServerIp(), 1));
newConnection.setServerIp(serverConn.getServerIp());
if (_connectionsPerIP.find(serverConn.getServerIp()) == _connectionsPerIP.end())
_connectionsPerIP.insert(std::pair<std::string, int>(serverConn.getServerIp(), 1));
else
_connectionsPerIP[conn.getServerIp()] += 1;
_connectionsPerIP[serverConn.getServerIp()] += 1;

newConnection.setServerPort(conn.getServerPort());
newConnection.setServerPort(serverConn.getServerPort());
/* start together */
_FDs.push_back(newSocketPoll);
_connections.push_back(newConnection);
Expand Down
7 changes: 7 additions & 0 deletions src/Server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "Parser.hpp"
#include "Config.hpp"
#include "ServerSocket.hpp"
#include "SSLContext.hpp"
#include "SSLManager.hpp"
#include "EventManager.hpp"

class Connection; // Forward declaration for circular dependencyA
Expand Down Expand Up @@ -65,6 +67,11 @@ class Server
int _CGICounter;
// number of connections per IP
std::map<std::string, int> _connectionsPerIP;
// SSL manager and context
// It is a pointer cause we want to use the singleton pattern
SSLManager *_sslManager;
// Normal object
SSLContext _sslContext;

/*** Private Methods ***/
Server();
Expand Down
6 changes: 2 additions & 4 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ int main(int argc, char **argv)
eventManager.subscribe(&serverEventListener);

// Initialize SSLManager and SSLContext
SSLManager *sslManager = SSLManager::getInstance();
(void)sslManager;
SSLContext sslContext;
(void)sslContext;
// SSLManager *sslManager = SSLManager::getInstance();
// SSLContext sslContext;

webserv.startListening();
webserv.startPollEventLoop();
Expand Down

0 comments on commit 58bd4cb

Please sign in to comment.