From 2d0b9f06b219d8b605ba4db3efb1f179347be525 Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Tue, 21 Aug 2018 11:57:56 -0400 Subject: [PATCH 1/2] add buffer pooling --- client.go | 18 ++++- compression_test.go | 9 ++- conn.go | 57 ++------------- conn_broadcast_test.go | 3 +- conn_test.go | 93 +++++++++++------------ json_test.go | 18 +++-- pool.go | 162 +++++++++++++++++++++++++++++++++++++++++ prepared.go | 5 +- prepared_test.go | 3 +- server.go | 18 ++++- 10 files changed, 271 insertions(+), 115 deletions(-) create mode 100644 pool.go diff --git a/client.go b/client.go index 070e182a..b7e550cd 100644 --- a/client.go +++ b/client.go @@ -82,6 +82,22 @@ type Dialer struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // WriteBufferPool is a Pool used to obtain write buffers. + // If a non-nil value is supplied, the WriteBufferSize field is ignored. + // If WriteBufferPool is nil and WriteBufferSize is not set, an internal + // pool will be used. Otherwise, new buffers will be allocated. + // If Get returns a value that is not a []byte with len > 0, + // bad things will happen. You have been warned. + WriteBufferPool Pool + + // BufReaderPool is a Pool used to obtain bufio.Readers. + // If a non-nil value is supplied, the ReadBufferSize field is ignored. + // If BufReaderPool is nil and ReadBufferSize is not set, an internal + // pool will be used. Otherwise, new bufio.Readers will be allocated. + // If Get returns a value that is not a *bufio.Reader, + // bad things will happen. You have been warned. + BufReaderPool Pool } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -275,7 +291,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + conn := newConn(netConn, false, d.getIOBuf()) if err := req.Write(netConn); err != nil { return nil, nil, err diff --git a/compression_test.go b/compression_test.go index 659cf421..3b3f9826 100644 --- a/compression_test.go +++ b/compression_test.go @@ -43,7 +43,8 @@ func textMessages(num int) [][]byte { func BenchmarkWriteNoCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf()) + defer c.Close() messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -54,7 +55,8 @@ func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf()) + defer c.Close() messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover @@ -66,7 +68,8 @@ func BenchmarkWriteWithCompression(b *testing.B) { } func TestValidCompressionLevel(t *testing.T) { - c := newConn(fakeNetConn{}, false, 1024, 1024) + c := newConn(fakeNetConn{}, false, (&Upgrader{}).getIOBuf()) + defer c.Close() for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { t.Errorf("no error for level %d", level) diff --git a/conn.go b/conn.go index 59a4bafd..71de5e8c 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,6 @@ package websocket import ( - "bufio" "encoding/binary" "errors" "io" @@ -230,9 +229,11 @@ type Conn struct { isServer bool subprotocol string + // Reusable I/O fields + ioBuf + // Write fields mu chan bool // used as mutex to protect write to conn - writeBuf []byte // frame is constructed in this buffer. writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection @@ -247,7 +248,6 @@ type Conn struct { // Read fields reader io.ReadCloser // the current reader returned to the application readErr error - br *bufio.Reader readRemaining int64 // bytes remaining in current frame. readFinal bool // true the current message has more frames. readLength int64 // Message size. @@ -264,10 +264,6 @@ type Conn struct { newDecompressionReader func(io.Reader) io.ReadCloser } -func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { - return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) -} - // writeHook is an io.Writer that steals the buffer that it is called with. type writeHook struct { p []byte @@ -278,60 +274,20 @@ func (wh *writeHook) Write(p []byte) (int, error) { return len(p), nil } -func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { +func newConn(conn net.Conn, isServer bool, iob ioBuf) *Conn { mu := make(chan bool, 1) mu <- true - var br *bufio.Reader - if readBufferSize == 0 && brw != nil && brw.Reader != nil { - // Reuse the supplied bufio.Reader if the buffer has a useful size. - // This code assumes that peek on a reader returns - // bufio.Reader.buf[:0]. - brw.Reader.Reset(conn) - if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { - br = brw.Reader - } - } - if br == nil { - if readBufferSize == 0 { - readBufferSize = defaultReadBufferSize - } - if readBufferSize < maxControlFramePayloadSize { - readBufferSize = maxControlFramePayloadSize - } - br = bufio.NewReaderSize(conn, readBufferSize) - } - - var writeBuf []byte - if writeBufferSize == 0 && brw != nil && brw.Writer != nil { - // Use the bufio.Writer's buffer if the buffer has a useful size. This - // code assumes that bufio.Writer.buf[:1] is passed to the - // bufio.Writer's underlying writer. - var wh writeHook - brw.Writer.Reset(&wh) - brw.Writer.WriteByte(0) - brw.Flush() - if cap(wh.p) >= maxFrameHeaderSize+256 { - writeBuf = wh.p[:cap(wh.p)] - } - } - - if writeBuf == nil { - if writeBufferSize == 0 { - writeBufferSize = defaultWriteBufferSize - } - writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) - } + iob.br.Reset(conn) c := &Conn{ isServer: isServer, - br: br, conn: conn, mu: mu, readFinal: true, - writeBuf: writeBuf, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, + ioBuf: iob, } c.SetCloseHandler(nil) c.SetPingHandler(nil) @@ -347,6 +303,7 @@ func (c *Conn) Subprotocol() string { // Close closes the underlying network connection without sending or waiting // for a close message. func (c *Conn) Close() error { + defer c.ioBuf.cleanup() return c.conn.Close() } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index 45038e48..31102381 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -70,7 +70,8 @@ func (b *broadcastBench) makeConns(numConns int) { conns := make([]*broadcastConn, numConns) for i := 0; i < numConns; i++ { - c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) + c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, (&Upgrader{}).getIOBuf()) + defer c.Close() if b.compression { c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover diff --git a/conn_test.go b/conn_test.go index 5fda7b5c..20b0c64f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -18,6 +18,13 @@ import ( "time" ) +func newTestConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { + return newConn(conn, isServer, (&Upgrader{ + ReadBufferSize: readBufferSize, + WriteBufferSize: writeBufferSize, + }).getIOBuf()) +} + var _ net.Error = errWriteTimeout type fakeNetConn struct { @@ -82,8 +89,10 @@ func TestFraming(t *testing.T) { for _, chunker := range readChunkers { var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, (&Upgrader{}).getIOBuf()) + defer wc.Close() + rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, (&Upgrader{}).getIOBuf()) + defer rc.Close() if compress { wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover @@ -143,8 +152,10 @@ func TestControl(t *testing.T) { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, (&Upgrader{}).getIOBuf()) + defer wc.Close() + rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, (&Upgrader{}).getIOBuf()) + defer wc.Close() if isWriteControl { wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { @@ -179,8 +190,10 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + defer wc.Close() + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf()) + defer rc.Close() w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -206,8 +219,8 @@ func TestEOFWithinFrame(t *testing.T) { for n := 0; ; n++ { var b bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, (&Upgrader{}).getIOBuf()) + rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, (&Upgrader{}).getIOBuf()) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize)) @@ -240,8 +253,10 @@ func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + defer wc.Close() + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf()) + defer rc.Close() w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -261,7 +276,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } func TestWriteAfterMessageWriterClose(t *testing.T) { - wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, (&Upgrader{}).getIOBuf()) + defer wc.Close() w, _ := wc.NextWriter(BinaryMessage) io.WriteString(w, "hello") if err := w.Close(); err != nil { @@ -292,8 +308,10 @@ func TestReadLimit(t *testing.T) { message := make([]byte, readLimit+1) var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) + defer wc.Close() + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf()) + defer rc.Close() rc.SetReadLimit(readLimit) // Send message at the limit with interleaved pong. @@ -321,7 +339,8 @@ func TestReadLimit(t *testing.T) { } func TestAddrs(t *testing.T) { - c := newConn(&fakeNetConn{}, true, 1024, 1024) + c := newConn(&fakeNetConn{}, true, (&Upgrader{}).getIOBuf()) + defer c.Close() if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) } @@ -333,7 +352,7 @@ func TestAddrs(t *testing.T) { func TestUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} - c := newConn(fc, true, 1024, 1024) + c := newConn(fc, true, (&Upgrader{}).getIOBuf()) ul := c.UnderlyingConn() if ul != fc { t.Fatalf("Underlying conn is not what it should be.") @@ -347,8 +366,10 @@ func TestBufioReadBytes(t *testing.T) { m[len(m)-1] = '\n' var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) + defer wc.Close() + rc := newTestConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + defer rc.Close() w, _ := wc.NextWriter(BinaryMessage) w.Write(m) @@ -423,7 +444,8 @@ func (w blockingWriter) Write(p []byte) (int, error) { func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf()) + defer c.Close() go func() { c.WriteMessage(TextMessage, []byte{}) }() @@ -449,7 +471,8 @@ func (r failingReader) Read(p []byte) (int, error) { } func TestFailedConnectionReadPanic(t *testing.T) { - c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, (&Upgrader{}).getIOBuf()) + defer c.Close() defer func() { if v := recover(); v != nil { @@ -462,35 +485,3 @@ func TestFailedConnectionReadPanic(t *testing.T) { } t.Fatal("should not get here") } - -func TestBufioReuse(t *testing.T) { - brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil)) - c := newConnBRW(nil, false, 0, 0, brw) - - if c.br != brw.Reader { - t.Error("connection did not reuse bufio.Reader") - } - - var wh writeHook - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection did not reuse bufio.Writer") - } - - brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0)) - c = newConnBRW(nil, false, 0, 0, brw) - - if c.br == brw.Reader { - t.Error("connection used bufio.Reader with small size") - } - - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection used bufio.Writer with small size") - } - -} diff --git a/json_test.go b/json_test.go index 61100e48..0198b662 100644 --- a/json_test.go +++ b/json_test.go @@ -15,8 +15,10 @@ import ( func TestJSON(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + wc := newConn(c, true, (&Upgrader{}).getIOBuf()) + defer wc.Close() + rc := newConn(c, false, (&Upgrader{}).getIOBuf()) + defer rc.Close() var actual, expect struct { A int @@ -41,8 +43,10 @@ func TestJSON(t *testing.T) { func TestPartialJSONRead(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + wc := newConn(c, true, (&Upgrader{}).getIOBuf()) + defer wc.Close() + rc := newConn(c, false, (&Upgrader{}).getIOBuf()) + defer rc.Close() var v struct { A int @@ -95,8 +99,10 @@ func TestPartialJSONRead(t *testing.T) { func TestDeprecatedJSON(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf} - wc := newConn(c, true, 1024, 1024) - rc := newConn(c, false, 1024, 1024) + wc := newConn(c, true, (&Upgrader{}).getIOBuf()) + defer wc.Close() + rc := newConn(c, false, (&Upgrader{}).getIOBuf()) + defer rc.Close() var actual, expect struct { A int diff --git a/pool.go b/pool.go new file mode 100644 index 00000000..b9184013 --- /dev/null +++ b/pool.go @@ -0,0 +1,162 @@ +package websocket + +import ( + "bufio" + "errors" + "sync" +) + +// Pool is a general-purpose interface for pooling. +// Pool is implemented by sync.Pool. +type Pool interface { + // Get returns a new value. + // Must never return nil. + Get() interface{} + + // Put saves a value if possible. + Put(interface{}) +} + +// NewBufferPool creates a Pool of byte slices with len of size. +func NewBufferPool(size int) Pool { + return &sync.Pool{ + New: func() interface{} { + return make([]byte, size) + }, + } +} + +// defaultWriteBufPool is a default pool for write buffers. +var defaultWriteBufPool = NewBufferPool(defaultWriteBufferSize + maxFrameHeaderSize) + +// NewBufReaderPool creates a new Pool of *bufio.Reader. +func NewBufReaderPool(size int) Pool { + return &sync.Pool{ + New: func() interface{} { + return bufio.NewReaderSize(nil, size) + }, + } +} + +// defaultBufReaderPool is a default pool for bufio.Readers. +var defaultBufReaderPool = NewBufReaderPool(defaultReadBufferSize) + +// putFunc is a callback used to return buffers +type putFunc func(interface{}) + +// put puts the value if the putFunc is non-nil. +func (p putFunc) put(i interface{}) { + if p != nil { + p(i) + } +} + +// ioBuf is a set of I/O buffers and mechanisms to reuse them. +type ioBuf struct { + // br is the buffered reader used to read the message stream. + br *bufio.Reader + + // writeBuf is a write buffer used to construct messages. + writeBuf []byte + + // putBR and putWBuf are putFuncs used to recycle the buffers. + putBR, putWBuf putFunc +} + +// getIOBuf gets a set of I/O buffers, pooling based on the Upgrader settings. +func (u *Upgrader) getIOBuf() ioBuf { + var writeBuf []byte + var writeBufPut putFunc + switch { + case u.WriteBufferPool != nil: + writeBuf, writeBufPut = u.WriteBufferPool.Get().([]byte), u.WriteBufferPool.Put + case u.WriteBufferSize != 0 && u.WriteBufferSize != defaultWriteBufferSize: + writeBuf = make([]byte, u.WriteBufferSize+maxFrameHeaderSize) + default: + writeBuf, writeBufPut = defaultWriteBufPool.Get().([]byte), defaultWriteBufPool.Put + } + + var br *bufio.Reader + var brPut putFunc + switch { + case u.BufReaderPool != nil: + br, brPut = u.BufReaderPool.Get().(*bufio.Reader), u.BufReaderPool.Put + case u.ReadBufferSize != 0 && u.ReadBufferSize != defaultReadBufferSize: + br = bufio.NewReaderSize(nil, u.ReadBufferSize) + default: + br, brPut = defaultBufReaderPool.Get().(*bufio.Reader), defaultBufReaderPool.Put + } + + return ioBuf{ + br: br, + writeBuf: writeBuf, + putBR: brPut, + putWBuf: writeBufPut, + } +} + +// getIOBuf gets a set of I/O buffers, pooling based on the Dialer settings. +func (d *Dialer) getIOBuf() ioBuf { + var writeBuf []byte + var writeBufPut putFunc + switch { + case d.WriteBufferPool != nil: + writeBuf, writeBufPut = d.WriteBufferPool.Get().([]byte), d.WriteBufferPool.Put + case d.WriteBufferSize != 0 && d.WriteBufferSize != defaultWriteBufferSize: + writeBuf = make([]byte, d.WriteBufferSize+maxFrameHeaderSize) + default: + writeBuf, writeBufPut = defaultWriteBufPool.Get().([]byte), defaultWriteBufPool.Put + } + + var br *bufio.Reader + var brPut putFunc + switch { + case d.BufReaderPool != nil: + br, brPut = d.BufReaderPool.Get().(*bufio.Reader), d.BufReaderPool.Put + case d.ReadBufferSize != 0 && d.ReadBufferSize != defaultReadBufferSize: + br = bufio.NewReaderSize(nil, d.ReadBufferSize) + default: + br, brPut = defaultBufReaderPool.Get().(*bufio.Reader), defaultBufReaderPool.Put + } + + return ioBuf{ + br: br, + writeBuf: writeBuf, + putBR: brPut, + putWBuf: writeBufPut, + } +} + +// cleanup recycles the I/O buffers and invalidates the ioBuf. +func (iob *ioBuf) cleanup() { + // reset bufio.Reader to allow underlying conn to be garbage collected + if iob.br != nil { + iob.br.Reset(nil) + } + + // recycle bufio reader and write buffer + iob.putBR.put(iob.br) + iob.putWBuf.put(iob.writeBuf) + + // clear ioBuf to prevent reuse + *iob = ioBuf{} +} + +// chanPool is a channel-based pool implementation for testing purposes +type chanPool chan interface{} + +func (c chanPool) Put(i interface{}) { + select { + case c <- i: + default: + } +} + +func (c chanPool) Get() interface{} { + select { + case i := <-c: + return i + default: + panic(errors.New("no value")) + } +} diff --git a/prepared.go b/prepared.go index 1efffbd1..c9925538 100644 --- a/prepared.go +++ b/prepared.go @@ -83,8 +83,11 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { isServer: key.isServer, compressionLevel: key.compressionLevel, enableWriteCompression: true, - writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + ioBuf: ioBuf{ + writeBuf: defaultWriteBufPool.Get().([]byte), + }, } + defer c.ioBuf.cleanup() if key.compress { c.newCompressionWriter = compressNoContextTakeover } diff --git a/prepared_test.go b/prepared_test.go index cf98c6c1..3177b60f 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -36,7 +36,8 @@ func TestPreparedMessage(t *testing.T) { for _, tt := range preparedMessageTests { var data = []byte("this is a test") var buf bytes.Buffer - c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024) + c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, (&Upgrader{}).getIOBuf()) + defer c.Close() if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } diff --git a/server.go b/server.go index ca5cb4fe..8d4b5537 100644 --- a/server.go +++ b/server.go @@ -58,6 +58,22 @@ type Upgrader struct { // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool + + // WriteBufferPool is a Pool used to obtain write buffers. + // If a non-nil value is supplied, the WriteBufferSize field is ignored. + // If WriteBufferPool is nil and WriteBufferSize is not set, an internal + // pool will be used. Otherwise, new buffers will be allocated. + // If Get returns a value that is not a []byte with len > 0, + // bad things will happen. You have been warned. + WriteBufferPool Pool + + // BufReaderPool is a Pool used to obtain bufio.Readers. + // If a non-nil value is supplied, the ReadBufferSize field is ignored. + // If BufReaderPool is nil and ReadBufferSize is not set, an internal + // pool will be used. Otherwise, new bufio.Readers will be allocated. + // If Get returns a value that is not a *bufio.Reader, + // bad things will happen. You have been warned. + BufReaderPool Pool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -173,7 +189,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return nil, errors.New("websocket: client sent data before handshake is complete") } - c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) + c := newConn(netConn, true, u.getIOBuf()) c.subprotocol = subprotocol if compress { From 774f2e3496c8e0c8d90088392d4b83cf0ebfd5b6 Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Wed, 22 Aug 2018 07:44:55 -0400 Subject: [PATCH 2/2] add pooling tests --- pool.go | 4 ++ pool_test.go | 114 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 pool_test.go diff --git a/pool.go b/pool.go index b9184013..2b100cfb 100644 --- a/pool.go +++ b/pool.go @@ -97,6 +97,10 @@ func (u *Upgrader) getIOBuf() ioBuf { // getIOBuf gets a set of I/O buffers, pooling based on the Dialer settings. func (d *Dialer) getIOBuf() ioBuf { + if d == nil { + d = &nilDialer + } + var writeBuf []byte var writeBufPut putFunc switch { diff --git a/pool_test.go b/pool_test.go new file mode 100644 index 00000000..0f58ab7e --- /dev/null +++ b/pool_test.go @@ -0,0 +1,114 @@ +package websocket + +import ( + "bufio" + "testing" +) + +func TestDialerGetIOBuf(t *testing.T) { + // prepare objects + wpool := make(chanPool, 1) + wbuf := make([]byte, defaultWriteBufferSize/4) + rpool := make(chanPool, 1) + br := bufio.NewReaderSize(nil, defaultReadBufferSize/4) + + // save buffers to pool + wpool.Put(wbuf) + rpool.Put(br) + + // get ioBuf using specific pools + chanPoolDialer := &Dialer{ + WriteBufferPool: wpool, + BufReaderPool: rpool, + } + iob := chanPoolDialer.getIOBuf() + if iob.br != br { + t.Errorf("Expected %T %p but got %p", br, br, iob.br) + } + if &iob.writeBuf[0] != &wbuf[0] { + t.Errorf("Expected %T %p but got %p", wbuf, wbuf, iob.writeBuf) + } + iob.cleanup() + if len(wpool) != 1 { + t.Error("Expected write buffer to be recycled, buffer not recycled") + } + if len(rpool) != 1 { + t.Error("Expected *bufio.Reader to be recycled, reader not recycled") + } + + // get ioBuf using size spec + sizeDialer := &Dialer{ + WriteBufferSize: 2 * defaultWriteBufferSize, + ReadBufferSize: 2 * defaultReadBufferSize, + } + iob = sizeDialer.getIOBuf() + if len(iob.writeBuf) != sizeDialer.WriteBufferSize+maxFrameHeaderSize { + t.Errorf("Expected write buffer with len %d but got len %d", sizeDialer.WriteBufferSize+maxFrameHeaderSize, len(iob.writeBuf)) + } + if iob.br.Size() != sizeDialer.ReadBufferSize { + t.Errorf("Expected buffered reader with size %d but got %d", sizeDialer.ReadBufferSize, iob.br.Size()) + } + + // get ioBuf using default dialer + iob = ((*Dialer)(nil)).getIOBuf() + if len(iob.writeBuf) != defaultWriteBufferSize+maxFrameHeaderSize { + t.Errorf("Expected write buffer with len %d but got len %d", defaultWriteBufferSize+maxFrameHeaderSize, len(iob.writeBuf)) + } + if iob.br.Size() != defaultReadBufferSize { + t.Errorf("Expected buffered reader with size %d but got %d", defaultReadBufferSize, iob.br.Size()) + } +} + +func TestUpgraderGetIOBuf(t *testing.T) { + // prepare objects + wpool := make(chanPool, 1) + wbuf := make([]byte, defaultWriteBufferSize/4) + rpool := make(chanPool, 1) + br := bufio.NewReaderSize(nil, defaultReadBufferSize/4) + + // save buffers to pool + wpool.Put(wbuf) + rpool.Put(br) + + // get ioBuf using specific pools + chanPoolUpgrader := &Upgrader{ + WriteBufferPool: wpool, + BufReaderPool: rpool, + } + iob := chanPoolUpgrader.getIOBuf() + if iob.br != br { + t.Errorf("Expected %T %p but got %p", br, br, iob.br) + } + if &iob.writeBuf[0] != &wbuf[0] { + t.Errorf("Expected %T %p but got %p", wbuf, wbuf, iob.writeBuf) + } + iob.cleanup() + if len(wpool) != 1 { + t.Error("Expected write buffer to be recycled, buffer not recycled") + } + if len(rpool) != 1 { + t.Error("Expected *bufio.Reader to be recycled, reader not recycled") + } + + // get ioBuf using size spec + sizeUpgrader := &Upgrader{ + WriteBufferSize: 2 * defaultWriteBufferSize, + ReadBufferSize: 2 * defaultReadBufferSize, + } + iob = sizeUpgrader.getIOBuf() + if len(iob.writeBuf) != sizeUpgrader.WriteBufferSize+maxFrameHeaderSize { + t.Errorf("Expected write buffer with len %d but got len %d", sizeUpgrader.WriteBufferSize+maxFrameHeaderSize, len(iob.writeBuf)) + } + if iob.br.Size() != sizeUpgrader.ReadBufferSize { + t.Errorf("Expected buffered reader with size %d but got %d", sizeUpgrader.ReadBufferSize, iob.br.Size()) + } + + // get ioBuf using default upgrader + iob = (&Upgrader{}).getIOBuf() + if len(iob.writeBuf) != defaultWriteBufferSize+maxFrameHeaderSize { + t.Errorf("Expected write buffer with len %d but got len %d", defaultWriteBufferSize+maxFrameHeaderSize, len(iob.writeBuf)) + } + if iob.br.Size() != defaultReadBufferSize { + t.Errorf("Expected buffered reader with size %d but got %d", defaultReadBufferSize, iob.br.Size()) + } +}