Skip to content

Commit

Permalink
非Linux平台使用recvfrom方案
Browse files Browse the repository at this point in the history
  • Loading branch information
xia-chu committed Jun 29, 2024
1 parent 55e7d7e commit 7d0ef47
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 112 deletions.
128 changes: 128 additions & 0 deletions src/Network/BufferSock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,4 +466,132 @@ BufferList::Ptr BufferList::create(List<std::pair<Buffer::Ptr, bool> > list, Sen
#endif
}

#if defined(__linux) || defined(__linux__)
class SocketRecvmmsgBuffer : public SocketRecvBuffer {
public:
SocketRecvmmsgBuffer(size_t count, size_t size)
: _size(size)
, _iovec(count)
, _mmsgs(count)
, _buffers(count)
, _address(count) {
for (auto i = 0u; i < count; ++i) {
auto buf = BufferRaw::create();
buf->setCapacity(size);

_buffers[i] = buf;
auto &mmsg = _mmsgs[i];
auto &addr = _address[i];
mmsg.msg_len = 0;
mmsg.msg_hdr.msg_name = &addr;
mmsg.msg_hdr.msg_namelen = sizeof(addr);
mmsg.msg_hdr.msg_iov = &_iovec[i];
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
mmsg.msg_hdr.msg_iov->iov_len = buf->getCapacity() - 1;
mmsg.msg_hdr.msg_iovlen = 1;
mmsg.msg_hdr.msg_control = nullptr;
mmsg.msg_hdr.msg_controllen = 0;
mmsg.msg_hdr.msg_flags = 0;
}
}

ssize_t recvFromSocket(int fd, ssize_t &count) override {
for (auto i = 0; i < _last_count; ++i) {
auto &mmsg = _mmsgs[i];
mmsg.msg_hdr.msg_namelen = sizeof(struct sockaddr_storage);
auto &buf = _buffers[i];
if (!buf) {
auto raw = BufferRaw::create();
raw->setCapacity(_size);
buf = raw;
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
}
}
do {
count = recvmmsg(fd, &_mmsgs[0], _mmsgs.size(), 0, nullptr);
} while (-1 == count && UV_EINTR == get_uv_error(true));

_last_count = count;
if (count <= 0) {
return count;
}

ssize_t nread = 0;
for (auto i = 0; i < count; ++i) {
auto &mmsg = _mmsgs[i];
nread += mmsg.msg_len;

auto buf = static_pointer_cast<BufferRaw>(_buffers[i]);
buf->setSize(mmsg.msg_len);
buf->data()[mmsg.msg_len] = '\0';
}
return nread;
}

Buffer::Ptr &getBuffer(size_t index) override { return _buffers[index]; }

struct sockaddr_storage &getAddress(size_t index) override { return _address[index]; }

private:
size_t _size;
ssize_t _last_count { 0 };
std::vector<struct iovec> _iovec;
std::vector<struct mmsghdr> _mmsgs;
std::vector<Buffer::Ptr> _buffers;
std::vector<struct sockaddr_storage> _address;
};
#endif

class SocketRecvFromBuffer : public SocketRecvBuffer {
public:
SocketRecvFromBuffer(size_t size): _size(size) {}

ssize_t recvFromSocket(int fd, ssize_t &count) override {
ssize_t nread;
socklen_t len = sizeof(_address);
if (!_buffer) {
allocBuffer();
}

do {
nread = recvfrom(fd, _buffer->data(), _buffer->getCapacity() - 1, 0, (struct sockaddr *)&_address, &len);
} while (-1 == nread && UV_EINTR == get_uv_error(true));

if (nread > 0) {
count = 1;
_buffer->data()[nread] = '\0';
static_pointer_cast<BufferRaw>(_buffer)->setSize(nread);

Check failure on line 563 in src/Network/BufferSock.cpp

View workflow job for this annotation

GitHub Actions / build

'static_pointer_cast': undeclared identifier

Check failure on line 563 in src/Network/BufferSock.cpp

View workflow job for this annotation

GitHub Actions / build

'toolkit::BufferRaw': illegal use of this type as an expression

Check failure on line 563 in src/Network/BufferSock.cpp

View workflow job for this annotation

GitHub Actions / build

'setSize': is not a member of 'toolkit::Buffer'
}
return nread;
}

Buffer::Ptr &getBuffer(size_t index) override { return _buffer; }

struct sockaddr_storage &getAddress(size_t index) override { return _address; }

private:
void allocBuffer() {
auto buf = BufferRaw::create();
buf->setCapacity(_size);
_buffer = std::move(buf);
}

private:
size_t _size;
Buffer::Ptr _buffer;
struct sockaddr_storage _address;
};

static constexpr auto kPacketCount = 32;
static constexpr auto kBufferCapacity = 4 * 1024u;

static SocketRecvBuffer::Ptr create(bool is_udp) {
#if defined(__linux) || defined(__linux__)
if (is_udp) {
return std::make_shared<SocketRecvmmsgBuffer>(kPacketCount, kBufferCapacity);
}
#endif
return std::make_shared<SocketRecvFromBuffer>(kPacketCount * kBufferCapacity);
}

} //toolkit
13 changes: 13 additions & 0 deletions src/Network/BufferSock.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,18 @@ class BufferList : public noncopyable {
ObjectStatistic<BufferList> _statistic;
};

class SocketRecvBuffer {
public:
using Ptr = std::shared_ptr<SocketRecvBuffer>;

virtual ~SocketRecvBuffer() = default;

virtual ssize_t recvFromSocket(int fd, ssize_t &count) = 0;
virtual Buffer::Ptr &getBuffer(size_t index) = 0;
virtual struct sockaddr_storage &getAddress(size_t index) = 0;

static Ptr create(bool is_udp);
};

}
#endif //ZLTOOLKIT_BUFFERSOCK_H
108 changes: 5 additions & 103 deletions src/Network/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ bool Socket::attachEvent(const SockNum::Ptr &sock) {
}

// tcp客户端或udp
auto read_buffer = _poller->getSharedBuffer();
auto read_buffer = _poller->getSharedBuffer(sock->type() == SockNum::Sock_UDP);
auto result = _poller->addEvent(sock->rawFd(), EventPoller::Event_Read | EventPoller::Event_Error | EventPoller::Event_Write, [weak_self, sock, read_buffer](int event) {
auto strong_self = weak_self.lock();
if (!strong_self) {
Expand All @@ -283,109 +283,11 @@ bool Socket::attachEvent(const SockNum::Ptr &sock) {
return -1 != result;
}

class MMsgBuffer {
public:
#if !defined(__linux)
struct mmsghdr {
struct msghdr msg_hdr; /* Message header */
unsigned int msg_len; /* Number of received bytes for header */
};
#endif

MMsgBuffer(size_t count, size_t size)
: _size(size)
, _iovec(count)
, _mmsgs(count)
, _buffers(count)
, _address(count) {
for (auto i = 0u; i < count; ++i) {
auto buf = BufferRaw::create();
buf->setCapacity(size);

_buffers[i] = buf;
auto &mmsg = _mmsgs[i];
auto &addr = _address[i];
mmsg.msg_len = 0;
mmsg.msg_hdr.msg_name = &addr;
mmsg.msg_hdr.msg_namelen = sizeof(addr);
mmsg.msg_hdr.msg_iov = &_iovec[i];
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
mmsg.msg_hdr.msg_iov->iov_len = buf->getCapacity() - 1;
mmsg.msg_hdr.msg_iovlen = 1;
mmsg.msg_hdr.msg_control = nullptr;
mmsg.msg_hdr.msg_controllen = 0;
mmsg.msg_hdr.msg_flags = 0;
}
}

ssize_t recvFromSocket(int fd, ssize_t &count) {
for (auto i = 0; i < _last_count; ++i) {
auto &mmsg = _mmsgs[i];
mmsg.msg_hdr.msg_namelen = sizeof(struct sockaddr_storage);
auto &buf = _buffers[i];
if (!buf) {
auto raw = BufferRaw::create();
raw->setCapacity(_size);
buf = raw;
mmsg.msg_hdr.msg_iov->iov_base = buf->data();
}
}
do {
count = recvmmsg(fd, &_mmsgs[0], _mmsgs.size(), 0, nullptr);
} while (-1 == count && UV_EINTR == get_uv_error(true));

_last_count = count;
if (count <= 0) {
return count;
}

ssize_t nread = 0;
for (auto i = 0; i < count; ++i) {
auto &mmsg = _mmsgs[i];
nread += mmsg.msg_len;

auto buf = static_pointer_cast<BufferRaw>(_buffers[i]);
buf->setSize(mmsg.msg_len);
buf->data()[mmsg.msg_len] = '\0';
}
return nread;
}

Buffer::Ptr &getBuffer(size_t index) { return _buffers[index]; }

struct sockaddr_storage &getAddress(size_t index) { return _address[index]; }

private:
#if !defined(__linux)
int recvmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int, int flags, struct timespec *) {
auto sz = recvmsg(sockfd, &(msgvec->msg_hdr), flags);
if (sz > 0) {
msgvec->msg_len = sz;
return 1;
}
return sz;
}
#endif

private:
size_t _size;
ssize_t _last_count { 0 };
std::vector<struct iovec> _iovec;
std::vector<struct mmsghdr> _mmsgs;
std::vector<Buffer::Ptr> _buffers;
std::vector<struct sockaddr_storage> _address;
};

static constexpr auto kPacketSize = 32;
static constexpr auto kBufferCapacity = 4 * 1024u;

ssize_t Socket::onRead(const SockNum::Ptr &sock, const BufferRaw::Ptr &) noexcept {
static thread_local MMsgBuffer buffer(kPacketSize,kBufferCapacity);

ssize_t Socket::onRead(const SockNum::Ptr &sock, const SocketRecvBuffer::Ptr &buffer) noexcept {
ssize_t ret = 0, nread = 0, count = 0;

while (_enable_recv) {
nread = buffer.recvFromSocket(sock->rawFd(), count);
nread = buffer->recvFromSocket(sock->rawFd(), count);
if (nread == 0) {
if (sock->type() == SockNum::Sock_TCP) {
emitErr(SockException(Err_eof, "end of file"));
Expand Down Expand Up @@ -413,8 +315,8 @@ ssize_t Socket::onRead(const SockNum::Ptr &sock, const BufferRaw::Ptr &) noexcep
_recv_speed += nread;
}

auto &buf = buffer.getBuffer(0);
auto &addr = buffer.getAddress(0);
auto &buf = buffer->getBuffer(0);
auto &addr = buffer->getAddress(0);
try {
// 此处捕获异常,目的是防止数据未读尽,epoll边沿触发失效的问题
LOCK_GUARD(_mtx_event);
Expand Down
2 changes: 1 addition & 1 deletion src/Network/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ class Socket : public std::enable_shared_from_this<Socket>, public noncopyable,

void setSock(SockNum::Ptr sock);
int onAccept(const SockNum::Ptr &sock, int event) noexcept;
ssize_t onRead(const SockNum::Ptr &sock, const BufferRaw::Ptr &buffer) noexcept;
ssize_t onRead(const SockNum::Ptr &sock, const SocketRecvBuffer::Ptr &buffer) noexcept;
void onWriteAble(const SockNum::Ptr &sock);
void onConnected(const SockNum::Ptr &sock, const onErrCB &cb);
void onFlushed();
Expand Down
14 changes: 8 additions & 6 deletions src/Poller/EventPoller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,15 @@ inline void EventPoller::onPipeEvent() {
});
}

BufferRaw::Ptr EventPoller::getSharedBuffer() {
auto ret = _shared_buffer.lock();
SocketRecvBuffer::Ptr EventPoller::getSharedBuffer(bool is_udp) {
#if !defined(__linux) && !defined(__linux__)
// 非Linux平台下,tcp和udp共享recvfrom方案,使用同一个buffer
is_udp = 0;
#endif
auto ret = _shared_buffer[is_udp].lock();
if (!ret) {
//预留一个字节存放\0结尾符
ret = BufferRaw::create();
ret->setCapacity(1 + SOCKET_DEFAULT_BUF_SIZE);
_shared_buffer = ret;
ret = SocketRecvBuffer::create(is_udp);
_shared_buffer[is_udp] = ret;
}
return ret;
}
Expand Down
5 changes: 3 additions & 2 deletions src/Poller/EventPoller.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "Thread/TaskExecutor.h"
#include "Thread/ThreadPool.h"
#include "Network/Buffer.h"
#include "Network/BufferSock.h"

#if defined(__linux__) || defined(__linux)
#define HAS_EPOLL
Expand Down Expand Up @@ -123,7 +124,7 @@ class EventPoller : public TaskExecutor, public AnyStorage, public std::enable_s
/**
* 获取当前线程下所有socket共享的读缓存
*/
BufferRaw::Ptr getSharedBuffer();
SocketRecvBuffer::Ptr getSharedBuffer(bool is_udp);

/**
* 获取poller线程id
Expand Down Expand Up @@ -192,7 +193,7 @@ class EventPoller : public TaskExecutor, public AnyStorage, public std::enable_s
//线程名
std::string _name;
//当前线程下,所有socket共享的读缓存
std::weak_ptr<BufferRaw> _shared_buffer;
std::weak_ptr<SocketRecvBuffer> _shared_buffer[2];
//执行事件循环的线程
std::thread *_loop_thread = nullptr;
//通知事件循环的线程已启动
Expand Down

0 comments on commit 7d0ef47

Please sign in to comment.