Skip to content

Commit

Permalink
AsyncWebSocket: Use shared_ptr for buffers
Browse files Browse the repository at this point in the history
Replace the hand-rolled buffer management logic with std::shared_ptr and
a basic RAII buffer class.  This simplifies memory management and seems
to fix a memory leak.
  • Loading branch information
willmmiles committed Feb 17, 2024
1 parent bb1982a commit 9311a64
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 366 deletions.
222 changes: 34 additions & 188 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,122 +109,6 @@ size_t webSocketSendFrame(AsyncClient *client, bool final, uint8_t opcode, bool
}


/*
* AsyncWebSocketMessageBuffer
*/



AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer()
:_data(nullptr)
,_len(0)
,_lock(false)
,_count(0)
{

}

AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t * data, size_t size)
:_data(nullptr)
,_len(size)
,_lock(false)
,_count(0)
{

if (!data) {
return;
}

_data = new uint8_t[_len + 1];

if (_data) {
memcpy(_data, data, _len);
_data[_len] = 0;
}
}


AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size)
:_data(nullptr)
,_len(size)
,_lock(false)
,_count(0)
{
_data = new uint8_t[_len + 1];

if (_data) {
_data[_len] = 0;
}

}

AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(const AsyncWebSocketMessageBuffer & copy)
:_data(nullptr)
,_len(0)
,_lock(false)
,_count(0)
{
_len = copy._len;
_lock = copy._lock;
_count = 0;

if (_len) {
_data = new uint8_t[_len + 1];
_data[_len] = 0;
}

if (_data) {
memcpy(_data, copy._data, _len);
_data[_len] = 0;
}

}

AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(AsyncWebSocketMessageBuffer && copy)
:_data(nullptr)
,_len(0)
,_lock(false)
,_count(0)
{
_len = copy._len;
_lock = copy._lock;
_count = 0;

if (copy._data) {
_data = copy._data;
copy._data = nullptr;
}

}

AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer()
{
if (_data) {
delete[] _data;
}
}

bool AsyncWebSocketMessageBuffer::reserve(size_t size)
{
_len = size;

if (_data) {
delete[] _data;
_data = nullptr;
}

_data = new uint8_t[_len + 1];

if (_data) {
_data[_len] = 0;
return true;
} else {
return false;
}

}



/*
* Control Frame
Expand Down Expand Up @@ -374,41 +258,31 @@ AsyncWebSocketBasicMessage::~AsyncWebSocketBasicMessage() {
*/


AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode, bool mask)
:_len(0)
,_sent(0)
AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer buffer, uint8_t opcode, bool mask)
:_sent(0)
,_ack(0)
,_acked(0)
,_WSbuffer(nullptr)
,_WSbuffer(std::move(buffer))
{

_opcode = opcode & 0x07;
_mask = mask;

if (buffer) {
_WSbuffer = buffer;
(*_WSbuffer)++;
_data = buffer->get();
_len = buffer->length();
if (_WSbuffer) {
_status = WS_MSG_SENDING;
//ets_printf("M: %u\n", _len);
} else {
_status = WS_MSG_ERROR;
}

}


AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() {
if (_WSbuffer) {
(*_WSbuffer)--; // decreases the counter.
}
}

void AsyncWebSocketMultiMessage::ack(size_t len, uint32_t time) {
(void)time;
_acked += len;
if(_sent >= _len && _acked >= _ack){
if(_sent >= _WSbuffer->size() && _acked >= _ack){
_status = WS_MSG_SENT;
}
//ets_printf("A: %u\n", len);
Expand All @@ -419,17 +293,17 @@ AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() {
if(_acked < _ack){
return 0;
}
if(_sent == _len){
if(_sent == _WSbuffer->size()){
_status = WS_MSG_SENT;
return 0;
}
if(_sent > _len){
if(_sent > _WSbuffer->size()){
_status = WS_MSG_ERROR;
//ets_printf("E: %u > %u\n", _sent, _len);
return 0;
}

size_t toSend = _len - _sent;
size_t toSend = _WSbuffer->size() - _sent;
size_t window = webSocketSendFrameWindow(client);

if(window < toSend) {
Expand All @@ -441,8 +315,8 @@ AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() {

//ets_printf("W: %u %u\n", _sent - toSend, toSend);

bool final = (_sent == _len);
uint8_t* dPtr = (uint8_t*)(_data + (_sent - toSend));
bool final = (_sent == _WSbuffer->size());
uint8_t* dPtr = (uint8_t*)(_WSbuffer->data() + (_sent - toSend));
uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION;

size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend);
Expand Down Expand Up @@ -512,7 +386,6 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){
if(len && !_messageQueue.isEmpty()){
_messageQueue.front()->ack(len, time);
}
_server->_cleanBuffers();
_runQueue();
}

Expand Down Expand Up @@ -869,11 +742,10 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data){
free(message);
}
}
void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer)
void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer buffer)
{
_queueMessage(new AsyncWebSocketMultiMessage(buffer));
_queueMessage(new AsyncWebSocketMultiMessage(std::move(buffer)));
}

void AsyncWebSocketClient::binary(const char * message, size_t len){
_queueMessage(new AsyncWebSocketBasicMessage(message, len, WS_BINARY));
}
Expand All @@ -900,9 +772,9 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len){
}

}
void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer)
void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer buffer)
{
_queueMessage(new AsyncWebSocketMultiMessage(buffer, WS_BINARY));
_queueMessage(new AsyncWebSocketMultiMessage(std::move(buffer), WS_BINARY));
}

IPAddress AsyncWebSocketClient::remoteIP() {
Expand Down Expand Up @@ -930,7 +802,6 @@ AsyncWebSocket::AsyncWebSocket(const String& url)
,_clients(LinkedList<AsyncWebSocketClient *>([](AsyncWebSocketClient *c){ delete c; }))
,_cNextId(1)
,_enabled(true)
,_buffers(LinkedList<AsyncWebSocketMessageBuffer *>([](AsyncWebSocketMessageBuffer *b){ delete b; }))
{
_eventHandler = NULL;
}
Expand Down Expand Up @@ -1023,22 +894,18 @@ void AsyncWebSocket::text(uint32_t id, const char * message, size_t len){
c->text(message, len);
}

void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer){
void AsyncWebSocket::textAll(const AsyncWebSocketMessageBuffer& buffer){
if (!buffer) return;
buffer->lock();
for(const auto& c: _clients){
if(c->status() == WS_CONNECTED){
c->text(buffer);
}
}
buffer->unlock();
_cleanBuffers();
}


void AsyncWebSocket::textAll(const char * message, size_t len){
AsyncWebSocketMessageBuffer * WSBuffer = makeBuffer((uint8_t *)message, len);
textAll(WSBuffer);
textAll(makeBuffer((uint8_t *)message, len));
}

void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len){
Expand All @@ -1048,20 +915,15 @@ void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len){
}

void AsyncWebSocket::binaryAll(const char * message, size_t len){
AsyncWebSocketMessageBuffer * buffer = makeBuffer((uint8_t *)message, len);
binaryAll(buffer);
binaryAll(makeBuffer((uint8_t *)message, len));
}

void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer)
void AsyncWebSocket::binaryAll(const AsyncWebSocketMessageBuffer &buffer)
{
if (!buffer) return;
buffer->lock();
for(const auto& c: _clients){
if(c->status() == WS_CONNECTED)
c->binary(buffer);
}
buffer->unlock();
_cleanBuffers();
}

void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessage *message){
Expand All @@ -1070,12 +932,11 @@ void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessage *message){
c->message(message);
}

void AsyncWebSocket::messageAll(AsyncWebSocketMultiMessage *message){
void AsyncWebSocket::messageAll(const AsyncWebSocketMultiMessage &message){
for(const auto& c: _clients){
if(c->status() == WS_CONNECTED)
c->message(message);
c->message(new AsyncWebSocketMultiMessage(message));
}
_cleanBuffers();
}

size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
Expand All @@ -1101,13 +962,13 @@ size_t AsyncWebSocket::printfAll(const char *format, ...) {
va_end(arg);
delete[] temp;

AsyncWebSocketMessageBuffer * buffer = makeBuffer(len);
AsyncWebSocketMessageBuffer buffer = makeBuffer(len);
if (!buffer) {
return 0;
}

va_start(arg, format);
vsnprintf( (char *)buffer->get(), len + 1, format, arg);
vsnprintf( (char *)buffer->data(), len, format, arg);
va_end(arg);

textAll(buffer);
Expand Down Expand Up @@ -1139,13 +1000,13 @@ size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) {
va_end(arg);
delete[] temp;

AsyncWebSocketMessageBuffer * buffer = makeBuffer(len + 1);
AsyncWebSocketMessageBuffer buffer = makeBuffer(len);
if (!buffer) {
return 0;
}

va_start(arg, formatP);
vsnprintf_P((char *)buffer->get(), len + 1, formatP, arg);
vsnprintf_P((char *)buffer->data(), len, formatP, arg);
va_end(arg);

textAll(buffer);
Expand Down Expand Up @@ -1273,37 +1134,22 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){
request->send(response);
}

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size)
AsyncWebSocketMessageBuffer AsyncWebSocket::makeBuffer(size_t size)
{
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(size);
if (buffer) {
AsyncWebLockGuard l(_lock);
_buffers.add(buffer);
AsyncWebSocketMessageBuffer buffer = std::make_shared<DynamicBuffer>(size);
if (buffer->size() == 0) {
buffer.reset();
}
return buffer;
return buffer;
}

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size)
AsyncWebSocketMessageBuffer AsyncWebSocket::makeBuffer(const uint8_t * data, size_t size)
{
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(data, size);

if (buffer) {
AsyncWebLockGuard l(_lock);
_buffers.add(buffer);
}

return buffer;
}

void AsyncWebSocket::_cleanBuffers()
{
AsyncWebLockGuard l(_lock);

for(AsyncWebSocketMessageBuffer * c: _buffers){
if(c && c->canDelete()){
_buffers.remove(c);
}
AsyncWebSocketMessageBuffer buffer = std::make_shared<DynamicBuffer>((const char*) data, size);
if (buffer->size() == 0) {
buffer.reset();
}
return buffer;
}

AsyncWebSocket::AsyncWebSocketClientLinkedList AsyncWebSocket::getClients() const {
Expand Down
Loading

0 comments on commit 9311a64

Please sign in to comment.