Skip to content

Commit

Permalink
Merge pull request #124 from lesismal/body_allocator
Browse files Browse the repository at this point in the history
websocket: add body allocator config
  • Loading branch information
lesismal authored Oct 9, 2021
2 parents f583f40 + e2dcdcd commit f0f634e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion examples/fixedbufferpool/fixedbufferpool.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package fixedbufferpool

import (
"os"
Expand Down
13 changes: 13 additions & 0 deletions nbhttp/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ type Config struct {
TLSConfig *tls.Config
TLSAllocator tls.Allocator

BodyAllocator mempool.Allocator

Context context.Context
Cancel func()

Expand All @@ -137,6 +139,7 @@ type Engine struct {
MaxLoad int
MaxWebsocketFramePayloadSize int
ReleaseWebsocketPayload bool
BodyAllocator mempool.Allocator
CheckUtf8 func(data []byte) bool

_onOpen func(c *nbio.Conn)
Expand Down Expand Up @@ -376,6 +379,9 @@ func NewEngine(conf Config, v ...interface{}) *Engine {
if conf.MaxWebsocketFramePayloadSize <= 0 {
conf.MaxWebsocketFramePayloadSize = DefaultMaxWebsocketFramePayloadSize
}
if conf.BodyAllocator == nil {
conf.BodyAllocator = mempool.DefaultMemPool
}

var handler = conf.Handler
if handler == nil {
Expand Down Expand Up @@ -475,6 +481,8 @@ func NewEngine(conf Config, v ...interface{}) *Engine {
emptyRequest: (&http.Request{}).WithContext(baseCtx),
BaseCtx: baseCtx,
Cancel: cancel,

BodyAllocator: conf.BodyAllocator,
}
if conf.SupportClient {
engine.InitTLSBuffers()
Expand Down Expand Up @@ -560,6 +568,9 @@ func NewEngineTLS(conf Config, v ...interface{}) *Engine {
if conf.TLSAllocator == nil {
conf.TLSAllocator = mempool.DefaultMemPool
}
if conf.BodyAllocator == nil {
conf.BodyAllocator = mempool.DefaultMemPool
}
if conf.MaxWebsocketFramePayloadSize <= 0 {
conf.MaxWebsocketFramePayloadSize = DefaultMaxWebsocketFramePayloadSize
}
Expand Down Expand Up @@ -685,6 +696,8 @@ func NewEngineTLS(conf Config, v ...interface{}) *Engine {
emptyRequest: (&http.Request{}).WithContext(baseCtx),
BaseCtx: baseCtx,
Cancel: cancel,

BodyAllocator: conf.BodyAllocator,
}
engine.InitTLSBuffers()

Expand Down
26 changes: 12 additions & 14 deletions nbhttp/websocket/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (u *Upgrader) OnMessage(h func(*Conn, MessageType, []byte)) {
if h != nil {
u.messageHandler = func(c *Conn, messageType MessageType, data []byte) {
if c.Engine.ReleaseWebsocketPayload {
defer mempool.Free(data)
defer c.Engine.BodyAllocator.Free(data)
}
h(c, messageType, data)
}
Expand All @@ -137,7 +137,7 @@ func (u *Upgrader) OnDataFrame(h func(*Conn, MessageType, bool, []byte)) {
if h != nil {
u.dataFrameHandler = func(c *Conn, messageType MessageType, fin bool, data []byte) {
if c.Engine.ReleaseWebsocketPayload {
defer mempool.Free(data)
defer c.Engine.BodyAllocator.Free(data)
}
h(c, messageType, fin, data)
}
Expand Down Expand Up @@ -348,7 +348,7 @@ func (u *Upgrader) Read(p *nbhttp.Parser, data []byte) error {
if u.dataFrameHandler != nil {
var frame []byte
if bl > 0 {
frame = mempool.Malloc(bl)
frame = u.Engine.BodyAllocator.Malloc(bl)
copy(frame, body)
}
if u.opcode == TextMessage && len(frame) > 0 && !u.Engine.CheckUtf8(frame) {
Expand All @@ -359,7 +359,7 @@ func (u *Upgrader) Read(p *nbhttp.Parser, data []byte) error {
}
if bl > 0 && u.messageHandler != nil {
if u.message == nil {
u.message = mempool.Malloc(len(body))
u.message = u.Engine.BodyAllocator.Malloc(len(body))
copy(u.message, body)
} else {
u.message = append(u.message, body...)
Expand All @@ -370,8 +370,8 @@ func (u *Upgrader) Read(p *nbhttp.Parser, data []byte) error {
if u.compress {
var b []byte
rc := decompressReader(io.MultiReader(bytes.NewBuffer(u.message), strings.NewReader(flateReaderTail)))
b, err = readAll(rc, len(u.message)*2)
mempool.Free(u.message)
b, err = u.readAll(rc, len(u.message)*2)
u.Engine.BodyAllocator.Free(u.message)
u.message = b
rc.Close()
if err != nil {
Expand All @@ -390,7 +390,7 @@ func (u *Upgrader) Read(p *nbhttp.Parser, data []byte) error {
} else {
var frame []byte
if len(body) > 0 {
frame = mempool.Malloc(len(body))
frame = u.Engine.BodyAllocator.Malloc(len(body))
copy(frame, body)
}
u.handleProtocolMessage(p, opcode, frame)
Expand Down Expand Up @@ -454,8 +454,8 @@ func (u *Upgrader) handleMessage(p *nbhttp.Parser, opcode MessageType, body []by
func (u *Upgrader) handleProtocolMessage(p *nbhttp.Parser, opcode MessageType, body []byte) {
p.Execute(func() {
u.handleWsMessage(u.conn, opcode, body)
if len(body) > 0 {
mempool.Free(body)
if len(body) > 0 && u.Engine.ReleaseWebsocketPayload {
u.Engine.BodyAllocator.Free(body)
}
})
}
Expand Down Expand Up @@ -841,9 +841,9 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", ""
}

func readAll(r io.Reader, size int) ([]byte, error) {
func (u *Upgrader) readAll(r io.Reader, size int) ([]byte, error) {
const maxAppendSize = 1024 * 1024 * 4
buf := mempool.Malloc(size)[0:0]
buf := u.Engine.BodyAllocator.Malloc(size)[0:0]
for {
n, err := r.Read(buf[len(buf):cap(buf)])
if n > 0 {
Expand All @@ -861,9 +861,7 @@ func readAll(r io.Reader, size int) ([]byte, error) {
if al > maxAppendSize {
al = maxAppendSize
}
tail := mempool.Malloc(al)
buf = append(buf, tail...)[:l]
mempool.Free(tail)
buf = append(buf, make([]byte, al)...)[:l]
}
}
}

0 comments on commit f0f634e

Please sign in to comment.