diff --git a/src/Network/Socket.cpp b/src/Network/Socket.cpp index 57b92991..7cb5d721 100644 --- a/src/Network/Socket.cpp +++ b/src/Network/Socket.cpp @@ -88,6 +88,11 @@ void Socket::setOnRead(onReadCB cb) { } } +void Socket::setOnMultiRead(onMultiReadCB cb) { + LOCK_GUARD(_mtx_event); + _on_multi_read = std::move(cb); +} + void Socket::setOnErr(onErrCB cb) { LOCK_GUARD(_mtx_event); if (cb) { @@ -282,10 +287,10 @@ class MMsgBuffer { , _buffers(count) , _address(count) { for (auto i = 0u; i < count; ++i) { - auto &buf = _buffers[i]; - buf = BufferRaw::create(); + auto buf = BufferRaw::create(); buf->setCapacity(size); + _buffers[i] = buf; auto &mmsg = _mmsgs[i]; auto &addr = _address[i]; mmsg.msg_len = 0; @@ -318,14 +323,14 @@ class MMsgBuffer { auto &mmsg = _mmsgs[i]; nread += mmsg.msg_len; - auto &buf = _buffers[i]; + auto buf = static_pointer_cast(_buffers[i]); buf->setSize(mmsg.msg_len); buf->data()[mmsg.msg_len] = '\0'; } return nread; } - const BufferRaw::Ptr &getBuffer(size_t index) const { return _buffers[index]; } + const Buffer::Ptr &getBuffer(size_t index) const { return _buffers[index]; } const struct sockaddr_storage &getAddress(size_t index) const { return _address[index]; } @@ -344,7 +349,7 @@ class MMsgBuffer { private: std::vector _iovec; std::vector _mmsgs; - std::vector _buffers; + std::vector _buffers; std::vector _address; }; @@ -386,14 +391,20 @@ ssize_t Socket::onRead(const SockNum::Ptr &sock, const BufferRaw::Ptr &) noexcep } LOCK_GUARD(_mtx_event); - for (auto i = 0u; i < count; ++i) { - auto &buf = buffer.getBuffer(i); - auto &addr = buffer.getAddress(i); - try { - // 此处捕获异常,目的是防止数据未读尽,epoll边沿触发失效的问题 - _on_read(buf, (struct sockaddr *)&addr, sizeof addr); - } catch (std::exception &ex) { - ErrorL << "Exception occurred when emit on_read: " << ex.what(); + if (_on_multi_read) { + auto &buf = buffer.getBuffer(0); + auto &addr = buffer.getAddress(0); + _on_multi_read(&buf, &addr, count); + } else { + for (auto i = 0u; i < count; ++i) { + auto &buf = buffer.getBuffer(i); + auto &addr = buffer.getAddress(i); + try { + // 此处捕获异常,目的是防止数据未读尽,epoll边沿触发失效的问题 + _on_read(buf, (struct sockaddr *)&addr, sizeof addr); + } catch (std::exception &ex) { + ErrorL << "Exception occurred when emit on_read: " << ex.what(); + } } } } diff --git a/src/Network/Socket.h b/src/Network/Socket.h index d43549aa..a2bd4af2 100644 --- a/src/Network/Socket.h +++ b/src/Network/Socket.h @@ -283,6 +283,8 @@ class Socket : public std::enable_shared_from_this, public noncopyable, using Ptr = std::shared_ptr; //接收数据回调 using onReadCB = std::function; + using onMultiReadCB = std::function; + //发生错误回调 using onErrCB = std::function; //tcp监听接收到连接请求 @@ -352,6 +354,7 @@ class Socket : public std::enable_shared_from_this, public noncopyable, * @param cb 回调对象 */ void setOnRead(onReadCB cb); + void setOnMultiRead(onMultiReadCB cb); /** * 设置异常事件(包括eof等)回调 @@ -566,6 +569,7 @@ class Socket : public std::enable_shared_from_this, public noncopyable, onErrCB _on_err; //收到数据事件 onReadCB _on_read; + onMultiReadCB _on_multi_read; //socket缓存清空事件(可用于发送流速控制) onFlush _on_flush; //tcp监听收到accept请求事件