Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add buffer pooling #4

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ type Dialer struct {
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar

// WriteBufferPool is a Pool used to obtain write buffers.
// If a non-nil value is supplied, the WriteBufferSize field is ignored.
// If WriteBufferPool is nil and WriteBufferSize is not set, an internal
// pool will be used. Otherwise, new buffers will be allocated.
// If Get returns a value that is not a []byte with len > 0,
// bad things will happen. You have been warned.
WriteBufferPool Pool

// BufReaderPool is a Pool used to obtain bufio.Readers.
// If a non-nil value is supplied, the ReadBufferSize field is ignored.
// If BufReaderPool is nil and ReadBufferSize is not set, an internal
// pool will be used. Otherwise, new bufio.Readers will be allocated.
// If Get returns a value that is not a *bufio.Reader,
// bad things will happen. You have been warned.
BufReaderPool Pool
}

var errMalformedURL = errors.New("malformed ws or wss URL")
Expand Down Expand Up @@ -275,7 +291,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
}

conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
conn := newConn(netConn, false, d.getIOBuf())

if err := req.Write(netConn); err != nil {
return nil, nil, err
Expand Down
9 changes: 6 additions & 3 deletions compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func textMessages(num int) [][]byte {

func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf())
defer c.Close()
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Expand All @@ -54,7 +55,8 @@ func BenchmarkWriteNoCompression(b *testing.B) {

func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf())
defer c.Close()
messages := textMessages(100)
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
Expand All @@ -66,7 +68,8 @@ func BenchmarkWriteWithCompression(b *testing.B) {
}

func TestValidCompressionLevel(t *testing.T) {
c := newConn(fakeNetConn{}, false, 1024, 1024)
c := newConn(fakeNetConn{}, false, (&Upgrader{}).getIOBuf())
defer c.Close()
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
if err := c.SetCompressionLevel(level); err == nil {
t.Errorf("no error for level %d", level)
Expand Down
57 changes: 7 additions & 50 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package websocket

import (
"bufio"
"encoding/binary"
"errors"
"io"
Expand Down Expand Up @@ -230,9 +229,11 @@ type Conn struct {
isServer bool
subprotocol string

// Reusable I/O fields
ioBuf

// Write fields
mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection
Expand All @@ -247,7 +248,6 @@ type Conn struct {
// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
Expand All @@ -264,10 +264,6 @@ type Conn struct {
newDecompressionReader func(io.Reader) io.ReadCloser
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}

// writeHook is an io.Writer that steals the buffer that it is called with.
type writeHook struct {
p []byte
Expand All @@ -278,60 +274,20 @@ func (wh *writeHook) Write(p []byte) (int, error) {
return len(p), nil
}

func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
func newConn(conn net.Conn, isServer bool, iob ioBuf) *Conn {
mu := make(chan bool, 1)
mu <- true

var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if br == nil {
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
br = bufio.NewReaderSize(conn, readBufferSize)
}

var writeBuf []byte
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
// Use the bufio.Writer's buffer if the buffer has a useful size. This
// code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
brw.Writer.Reset(&wh)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}

if writeBuf == nil {
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
}
iob.br.Reset(conn)

c := &Conn{
isServer: isServer,
br: br,
conn: conn,
mu: mu,
readFinal: true,
writeBuf: writeBuf,
enableWriteCompression: true,
compressionLevel: defaultCompressionLevel,
ioBuf: iob,
}
c.SetCloseHandler(nil)
c.SetPingHandler(nil)
Expand All @@ -347,6 +303,7 @@ func (c *Conn) Subprotocol() string {
// Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error {
defer c.ioBuf.cleanup()
return c.conn.Close()
}

Expand Down
3 changes: 2 additions & 1 deletion conn_broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func (b *broadcastBench) makeConns(numConns int) {
conns := make([]*broadcastConn, numConns)

for i := 0; i < numConns; i++ {
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, (&Upgrader{}).getIOBuf())
defer c.Close()
if b.compression {
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
Expand Down
93 changes: 42 additions & 51 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ import (
"time"
)

func newTestConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConn(conn, isServer, (&Upgrader{
ReadBufferSize: readBufferSize,
WriteBufferSize: writeBufferSize,
}).getIOBuf())
}

var _ net.Error = errWriteTimeout

type fakeNetConn struct {
Expand Down Expand Up @@ -82,8 +89,10 @@ func TestFraming(t *testing.T) {
for _, chunker := range readChunkers {

var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, (&Upgrader{}).getIOBuf())
defer wc.Close()
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, (&Upgrader{}).getIOBuf())
defer rc.Close()
if compress {
wc.newCompressionWriter = compressNoContextTakeover
rc.newDecompressionReader = decompressNoContextTakeover
Expand Down Expand Up @@ -143,8 +152,10 @@ func TestControl(t *testing.T) {
for _, isWriteControl := range []bool{true, false} {
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, (&Upgrader{}).getIOBuf())
defer wc.Close()
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, (&Upgrader{}).getIOBuf())
defer wc.Close()
if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
} else {
Expand Down Expand Up @@ -179,8 +190,10 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}

var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
defer wc.Close()
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf())
defer rc.Close()

w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
Expand All @@ -206,8 +219,8 @@ func TestEOFWithinFrame(t *testing.T) {

for n := 0; ; n++ {
var b bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, (&Upgrader{}).getIOBuf())
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, (&Upgrader{}).getIOBuf())

w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize))
Expand Down Expand Up @@ -240,8 +253,10 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
const bufSize = 512

var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
defer wc.Close()
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf())
defer rc.Close()

w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
Expand All @@ -261,7 +276,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
}

func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, (&Upgrader{}).getIOBuf())
defer wc.Close()
w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
if err := w.Close(); err != nil {
Expand Down Expand Up @@ -292,8 +308,10 @@ func TestReadLimit(t *testing.T) {
message := make([]byte, readLimit+1)

var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
defer wc.Close()
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, (&Upgrader{}).getIOBuf())
defer rc.Close()
rc.SetReadLimit(readLimit)

// Send message at the limit with interleaved pong.
Expand Down Expand Up @@ -321,7 +339,8 @@ func TestReadLimit(t *testing.T) {
}

func TestAddrs(t *testing.T) {
c := newConn(&fakeNetConn{}, true, 1024, 1024)
c := newConn(&fakeNetConn{}, true, (&Upgrader{}).getIOBuf())
defer c.Close()
if c.LocalAddr() != localAddr {
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
}
Expand All @@ -333,7 +352,7 @@ func TestAddrs(t *testing.T) {
func TestUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024)
c := newConn(fc, true, (&Upgrader{}).getIOBuf())
ul := c.UnderlyingConn()
if ul != fc {
t.Fatalf("Underlying conn is not what it should be.")
Expand All @@ -347,8 +366,10 @@ func TestBufioReadBytes(t *testing.T) {
m[len(m)-1] = '\n'

var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
wc := newTestConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
defer wc.Close()
rc := newTestConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
defer rc.Close()

w, _ := wc.NextWriter(BinaryMessage)
w.Write(m)
Expand Down Expand Up @@ -423,7 +444,8 @@ func (w blockingWriter) Write(p []byte) (int, error) {

func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, (&Upgrader{}).getIOBuf())
defer c.Close()
go func() {
c.WriteMessage(TextMessage, []byte{})
}()
Expand All @@ -449,7 +471,8 @@ func (r failingReader) Read(p []byte) (int, error) {
}

func TestFailedConnectionReadPanic(t *testing.T) {
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, (&Upgrader{}).getIOBuf())
defer c.Close()

defer func() {
if v := recover(); v != nil {
Expand All @@ -462,35 +485,3 @@ func TestFailedConnectionReadPanic(t *testing.T) {
}
t.Fatal("should not get here")
}

func TestBufioReuse(t *testing.T) {
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
c := newConnBRW(nil, false, 0, 0, brw)

if c.br != brw.Reader {
t.Error("connection did not reuse bufio.Reader")
}

var wh writeHook
brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection did not reuse bufio.Writer")
}

brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
c = newConnBRW(nil, false, 0, 0, brw)

if c.br == brw.Reader {
t.Error("connection used bufio.Reader with small size")
}

brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection used bufio.Writer with small size")
}

}
Loading