From 25b9c8fd1d607d10eecdc5399f9f0db80a876e90 Mon Sep 17 00:00:00 2001 From: Artur Troian Date: Mon, 28 Aug 2017 13:49:07 +0300 Subject: [PATCH] Optimize tcp read/write operations read: use goroutine with timer per each connection to handle keep alive insdead of using SetReadDeadline which requires call to time.Now. That leads to lots of syscalls on heavy loaded connections write: use net.Buffers to perform batch writes. buffers flushed every 1ms (tune this period if needed) Signed-off-by: Artur Troian --- Gopkg.lock | 8 +- buffer/buffer.go | 665 ------------------------------------ buffer/buffer_test.go | 269 --------------- connection/net.go | 184 ++++++---- connection/netCallbacks.go | 5 +- examples/surgemq/surgemq.go | 23 +- packet/connack_test.go | 4 +- packet/routines.go | 10 +- routines/misc.go | 6 - transport/base.go | 2 + 10 files changed, 140 insertions(+), 1036 deletions(-) delete mode 100644 buffer/buffer.go delete mode 100644 buffer/buffer_test.go diff --git a/Gopkg.lock b/Gopkg.lock index a4d877e..1d73817 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -103,6 +103,12 @@ revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" version = "v1.1.4" +[[projects]] + branch = "master" + name = "github.com/troian/goring" + packages = ["."] + revision = "f23b2d237abc4603ebeb2509c2cf1907debdf83f" + [[projects]] branch = "master" name = "github.com/troian/omap" @@ -154,6 +160,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "ff53b5f9fcbd74c1a6d7008101ed8b2f5145a7ffc9c6fc7e77ee659993d4a1f7" + inputs-digest = "79d21ebc8b01ced112de19100ca600001e9e62e84ceddc6efc2fe21118ac8a63" solver-name = "gps-cdcl" solver-version = 1 diff --git a/buffer/buffer.go b/buffer/buffer.go deleted file mode 100644 index fda2683..0000000 --- a/buffer/buffer.go +++ /dev/null @@ -1,665 +0,0 @@ -// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bufio" - "errors" - "fmt" - "io" - "sync" - "sync/atomic" -) - -var ( - bufCNT int64 -) - -var ( - // ErrInsufficientData buffer has insufficient data - ErrInsufficientData = errors.New("buffer has insufficient data") - - // ErrInsufficientSpace buffer has insufficient space - ErrInsufficientSpace = errors.New("buffer has insufficient space") - - // ErrNotReady buffer is not ready yet - ErrNotReady = errors.New("buffer is not ready") -) - -const ( - // DefaultBufferSize buffer size created by default - DefaultBufferSize = 1024 * 256 - // DefaultReadBlockSize default read block size - DefaultReadBlockSize = 1500 - // DefaultWriteBlockSize default write block size - DefaultWriteBlockSize = 1500 -) - -type sequence struct { - // The current position of the producer or consumer - cursor int64 - - // The previous known position of the consumer (if producer) or producer (if consumer) - gate int64 - - // These are fillers to pad the cache line, which is generally 64 bytes - //p2 int64 - //p3 int64 - //p4 int64 - //p5 int64 - //p6 int64 - //p7 int64 -} - -func newSequence() *sequence { - return &sequence{} -} - -func (b *sequence) get() int64 { - return atomic.LoadInt64(&b.cursor) -} - -func (b *sequence) set(seq int64) { - atomic.StoreInt64(&b.cursor, seq) -} - -// Type of buffer -// align atomic values to prevent panics on 32 bits macnines -// see https://github.com/golang/go/issues/5278 -type Type struct { - id int64 - size int64 - mask int64 - done int64 - - buf []byte - tmp []byte - ExternalBuf []byte - - pSeq *sequence - cSeq *sequence - pCond *sync.Cond - cCond *sync.Cond -} - -// New buffer -func New(size int64) (*Type, error) { - if size < 0 { - return nil, bufio.ErrNegativeCount - } - - if size == 0 { - size = DefaultBufferSize - } - - if !powerOfTwo64(size) { - return nil, fmt.Errorf("Size must be power of two. Try %d", roundUpPowerOfTwo64(size)) - } - - if size < 2*DefaultReadBlockSize { - return nil, fmt.Errorf("Size must at least be %d. Try %d", 2*DefaultReadBlockSize, 2*DefaultReadBlockSize) - } - - return &Type{ - id: atomic.AddInt64(&bufCNT, 1), - ExternalBuf: make([]byte, size), - buf: make([]byte, size), - size: size, - mask: size - 1, - pSeq: newSequence(), - cSeq: newSequence(), - pCond: sync.NewCond(new(sync.Mutex)), - cCond: sync.NewCond(new(sync.Mutex)), - }, nil -} - -// ID of buffer -func (b *Type) ID() int64 { - return b.id -} - -// Close buffer -func (b *Type) Close() error { - atomic.StoreInt64(&b.done, 1) - - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - - b.pCond.L.Lock() - b.cCond.Broadcast() - b.pCond.L.Unlock() - - return nil -} - -// Len of data -func (b *Type) Len() int { - cpos := b.cSeq.get() - ppos := b.pSeq.get() - return int(ppos - cpos) -} - -// Size of buffer -func (b *Type) Size() int64 { - return b.size -} - -// ReadFrom from reader -func (b *Type) ReadFrom(r io.Reader) (int64, error) { - total := int64(0) - - for { - if b.isDone() { - return total, io.EOF - } - - start, cnt, err := b.waitForWriteSpace(DefaultReadBlockSize) - if err != nil { - return 0, err - } - - pStart := start & b.mask - pEnd := pStart + int64(cnt) - if pEnd > b.size { - pEnd = b.size - } - - n, err := r.Read(b.buf[pStart:pEnd]) - if n > 0 { - total += int64(n) - if _, err = b.WriteCommit(n); err != nil { - return total, err - } - } - - if err != nil { - return total, err - } - } -} - -// WriteTo to writer -func (b *Type) WriteTo(w io.Writer) (int64, error) { - total := int64(0) - - for { - if b.isDone() { - return total, io.EOF - } - - p, err := b.ReadPeek(DefaultWriteBlockSize) - // There's some data, let's process it first - if len(p) > 0 { - var n int - n, err = w.Write(p) - total += int64(n) - - if err != nil { - return total, err - } - - _, err = b.ReadCommit(n) - if err != nil { - return total, err - } - } - - if err != nil { - if err != ErrInsufficientData { - return total, err - } - } - } -} - -// Read data -func (b *Type) Read(p []byte) (int, error) { - if b.isDone() && b.Len() == 0 { - return 0, io.EOF - } - - pl := int64(len(p)) - - for { - cPos := b.cSeq.get() - pPos := b.pSeq.get() - cIndex := cPos & b.mask - - // If consumer position is at least len(p) less than producer position, that means - // we have enough data to fill p. There are two scenarios that could happen: - // 1. cIndex + len(p) < buffer size, in this case, we can just copy() data from - // buffer to p, and copy will just copy enough to fill p and stop. - // The number of bytes copied will be len(p). - // 2. cIndex + len(p) > buffer size, this means the data will wrap around to the - // the beginning of the buffer. In thise case, we can also just copy data from - // buffer to p, and copy will just copy until the end of the buffer and stop. - // The number of bytes will NOT be len(p) but less than that. - if cPos+pl < pPos { - n := copy(p, b.buf[cIndex:]) - - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - - return n, nil - } - - // If we got here, that means there's not len(p) data available, but there might - // still be data. - - // If cPos < pPos, that means there's at least pPos-cPos bytes to read. Let's just - // send that back for now. - if cPos < pPos { - // n bytes available - avail := pPos - cPos - - // bytes copied - var n int - - // if cIndex+n < size, that means we can copy all n bytes into p. - // No wrapping in this case. - if cIndex+avail < b.size { - n = copy(p, b.buf[cIndex:cIndex+avail]) - } else { - // If cIndex+n >= size, that means we can copy to the end of buffer - n = copy(p, b.buf[cIndex:]) - } - - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - return n, nil - } - - // If we got here, that means cPos >= pPos, which means there's no data available. - // If so, let's wait... - - b.cCond.L.Lock() - for pPos = b.pSeq.get(); cPos >= pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return 0, io.EOF - } - - //b.cWait++ - b.cCond.Wait() - } - b.cCond.L.Unlock() - } -} - -// Write message -func (b *Type) Write(p []byte) (int, error) { - if b.isDone() { - return 0, io.EOF - } - - start, _, err := b.waitForWriteSpace(len(p)) - if err != nil { - return 0, err - } - - // If we are here that means we now have enough space to write the full p. - // Let's copy from p into this.buf, starting at position ppos&this.mask. - total := ringCopy(b.buf, p, start&b.mask) - - b.pSeq.set(start + int64(len(p))) - b.cCond.L.Lock() - b.cCond.Broadcast() - b.cCond.L.Unlock() - - return total, nil -} - -// ReadPeek Description below is copied completely from bufio.Peek() -// http://golang.org/pkg/bufio/#Reader.Peek -// Peek returns the next n bytes without advancing the reader. The bytes stop being valid -// at the next read call. If Peek returns fewer than n bytes, it also returns an error -// explaining why the read is short. The error is bufio.ErrBufferFull if n is larger than -// b's buffer size. -// If there's not enough data to peek, error is ErrBufferInsufficientData. -// If n < 0, error is bufio.ErrNegativeCount -func (b *Type) ReadPeek(n int) ([]byte, error) { - if int64(n) > b.size { - return nil, bufio.ErrBufferFull - } - - if n < 0 { - return nil, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // If there's no data, then let's wait until there is some data - b.cCond.L.Lock() - for ; cPos >= pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return nil, io.EOF - } - - //b.cWait++ - b.cCond.Wait() - } - b.cCond.L.Unlock() - - // m = the number of bytes available. If m is more than what's requested (n), - // then we make m = n, basically peek max n bytes - m := pPos - cPos - err := error(nil) - - if m >= int64(n) { - m = int64(n) - } else { - err = ErrInsufficientData - } - - // There's data to peek. The size of the data could be <= n. - if cPos+m <= pPos { - cindex := cPos & b.mask - - // If cindex (index relative to buffer) + n is more than buffer size, that means - // the data wrapped - if cindex+m > b.size { - // reset the tmp buffer - b.tmp = b.tmp[0:0] - - l := len(b.buf[cindex:]) - b.tmp = append(b.tmp, b.buf[cindex:]...) - b.tmp = append(b.tmp, b.buf[0:m-int64(l)]...) - return b.tmp, err - } - - return b.buf[cindex : cindex+m], err - } - - return nil, ErrInsufficientData -} - -// ReadWait waits for for n bytes to be ready. If there's not enough data, then it will -// wait until there's enough. This differs from ReadPeek or Readin that Peek will -// return whatever is available and won't wait for full count. -func (b *Type) ReadWait(n int) ([]byte, error) { - if int64(n) > b.size { - return nil, bufio.ErrBufferFull - } - - if n < 0 { - return nil, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // This is the magic read-to position. The producer position must be equal or - // greater than the next position we read to. - next := cPos + int64(n) - - // If there's no data, then let's wait until there is some data - b.cCond.L.Lock() - for ; next > pPos; pPos = b.pSeq.get() { - if b.isDone() { - b.cCond.L.Unlock() - return nil, io.EOF - } - - b.cCond.Wait() - } - b.cCond.L.Unlock() - - //if b.isDone() { - // return nil, io.EOF - //} - - // If we are here that means we have at least n bytes of data available. - cIndex := cPos & b.mask - - // If cIndex (index relative to buffer) + n is more than buffer size, that means - // the data wrapped - if cIndex+int64(n) > b.size { - // reset the tmp buffer - b.tmp = b.tmp[0:0] - - l := len(b.buf[cIndex:]) - b.tmp = append(b.tmp, b.buf[cIndex:]...) - b.tmp = append(b.tmp, b.buf[0:n-l]...) - return b.tmp[:n], nil - } - - return b.buf[cIndex : cIndex+int64(n)], nil -} - -// ReadCommit Commit moves the cursor forward by n bytes. It behaves like Read() except it doesn't -// return any data. If there's enough data, then the cursor will be moved forward and -// n will be returned. If there's not enough data, then the cursor will move forward -// as much as possible, then return the number of positions (bytes) moved. -func (b *Type) ReadCommit(n int) (int, error) { - if int64(n) > b.size { - return 0, bufio.ErrBufferFull - } - - if n < 0 { - return 0, bufio.ErrNegativeCount - } - - cPos := b.cSeq.get() - pPos := b.pSeq.get() - - // If consumer position is at least n less than producer position, that means - // we have enough data to fill p. There are two scenarios that could happen: - // 1. cindex + n < buffer size, in this case, we can just copy() data from - // buffer to p, and copy will just copy enough to fill p and stop. - // The number of bytes copied will be len(p). - // 2. cindex + n > buffer size, this means the data will wrap around to the - // the beginning of the buffer. In thise case, we can also just copy data from - // buffer to p, and copy will just copy until the end of the buffer and stop. - // The number of bytes will NOT be len(p) but less than that. - if cPos+int64(n) <= pPos { - b.cSeq.set(cPos + int64(n)) - b.pCond.L.Lock() - b.pCond.Broadcast() - b.pCond.L.Unlock() - return n, nil - } - - return 0, ErrInsufficientData -} - -// WriteWait waits for n bytes to be available in the buffer and then returns -// 1. the slice pointing to the location in the buffer to be filled -// 2. a boolean indicating whether the bytes available wraps around the ring -// 3. any errors encountered. If there's error then other return values are invalid -func (b *Type) WriteWait(n int) ([]byte, bool, error) { - start, cnt, err := b.waitForWriteSpace(n) - if err != nil { - return nil, false, err - } - - pStart := start & b.mask - if pStart+int64(cnt) > b.size { - return b.buf[pStart:], true, nil - } - - return b.buf[pStart : pStart+int64(cnt)], false, nil -} - -// WriteCommit write with commit -func (b *Type) WriteCommit(n int) (int, error) { - start, cnt, err := b.waitForWriteSpace(n) - if err != nil { - return 0, err - } - - // If we are here then there's enough bytes to commit - b.pSeq.set(start + int64(cnt)) - - b.cCond.L.Lock() - b.cCond.Broadcast() - b.cCond.L.Unlock() - - return cnt, nil -} - -// Send to -func (b *Type) Send(from [][]byte) (int, error) { - defer func() { - if int64(len(b.ExternalBuf)) > b.size { - b.ExternalBuf = make([]byte, b.size) - } - }() - - var total int - - for _, s := range from { - remaining := len(s) - offset := 0 - for remaining > 0 { - toWrite := remaining - if toWrite > int(b.Size()) { - toWrite = int(b.Size()) - } - - var wrote int - var err error - - if wrote, err = b.Write(s[offset : offset+toWrite]); err != nil { - return 0, err - } - - remaining -= wrote - offset += wrote - } - total += len(s) - } - - return total, nil -} - -func (b *Type) waitForWriteSpace(n int) (int64, int, error) { - if b.isDone() { - return 0, 0, io.EOF - } - - // The current producer position, remember it's a forever inreasing int64, - // NOT the position relative to the buffer - pPos := b.pSeq.get() - - // The next producer position we will get to if we write len(p) - next := pPos + int64(n) - - // For the producer, gate is the previous consumer sequence. - gate := b.pSeq.gate - - wrap := next - b.size - - // If wrap point is greater than gate, that means the consumer hasn't read - // some of the data in the buffer, and if we read in additional data and put - // into the buffer, we would overwrite some of the unread data. It means we - // cannot do anything until the customers have passed it. So we wait... - // - // Let's say size = 16, block = 4, pPos = 0, gate = 0 - // then next = 4 (0+4), and wrap = -12 (4-16) - // _______________________________________________________________________ - // | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | - // ----------------------------------------------------------------------- - // ^ ^ - // pPos, next - // gate - // - // So wrap (-12) > gate (0) = false, and gate (0) > pPos (0) = false also, - // so we move on (no waiting) - // - // Now if we get to pPos = 14, gate = 12, - // then next = 18 (4+14) and wrap = 2 (18-16) - // - // So wrap (2) > gate (12) = false, and gate (12) > pPos (14) = false aos, - // so we move on again - // - // Now let's say we have pPos = 14, gate = 0 still (nothing read), - // then next = 18 (4+14) and wrap = 2 (18-16) - // - // So wrap (2) > gate (0) = true, which means we have to wait because if we - // put data into the slice to the wrap point, it would overwrite the 2 bytes - // that are currently unread. - // - // Another scenario, let's say pPos = 100, gate = 80, - // then next = 104 (100+4) and wrap = 88 (104-16) - // - // So wrap (88) > gate (80) = true, which means we have to wait because if we - // put data into the slice to the wrap point, it would overwrite the 8 bytes - // that are currently unread. - // - if wrap > gate || gate > pPos { - var cPos int64 - b.pCond.L.Lock() - for cPos = b.cSeq.get(); wrap > cPos; cPos = b.cSeq.get() { - if b.isDone() { - return 0, 0, io.EOF - } - - //b.pWait++ - b.pCond.Wait() - } - - b.pSeq.gate = cPos - b.pCond.L.Unlock() - } - - return pPos, n, nil -} - -func (b *Type) isDone() bool { - return atomic.LoadInt64(&b.done) == 1 -} - -func ringCopy(dst, src []byte, start int64) int { - n := len(src) - - var i int - var l int - - for n > 0 { - l = copy(dst[start:], src[i:]) - i += l - n -= l - - if n > 0 { - start = 0 - } - } - - return i -} - -func powerOfTwo64(n int64) bool { - return n != 0 && (n&(n-1)) == 0 -} - -func roundUpPowerOfTwo64(n int64) int64 { - n-- - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n |= n >> 32 - n++ - - return n -} diff --git a/buffer/buffer_test.go b/buffer/buffer_test.go deleted file mode 100644 index 75aa7e5..0000000 --- a/buffer/buffer_test.go +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "io" - "testing" - "time" - - "bufio" - - "github.com/stretchr/testify/require" -) - -func TestBufferSequence(t *testing.T) { - seq := newSequence() - - seq.set(100) - require.Equal(t, int64(100), seq.get()) - - seq.set(20000) - require.Equal(t, int64(20000), seq.get()) -} - -func TestBufferReadFrom(t *testing.T) { - testFillBuffer(t, 144, 16384) - testFillBuffer(t, 2048, 16384) - testFillBuffer(t, 3072, 16384) -} - -func TestBufferReadBytes(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - testReadBytes(t, buf) -} - -func TestBufferCommitBytes(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - testCommit(t, buf) -} - -func TestBufferConsumerProducerRead(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testRead(t, buf) -} - -func TestBufferConsumerProducerWriteTo(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testWriteTo(t, buf) -} - -func TestBufferConsumerProducerPeekCommit(t *testing.T) { - buf, err := New(16384) - - require.NoError(t, err) - - testPeekCommit(t, buf) -} - -func TestBufferPeek(t *testing.T) { - buf := testFillBuffer(t, 2048, 16384) - - peekBuffer(t, buf, 100) - peekBuffer(t, buf, 1000) -} - -func TestBufferNew(t *testing.T) { - _, err := New(-1) - require.EqualError(t, bufio.ErrNegativeCount, err.Error()) - - _, err = New(863) - require.Error(t, err) - - _, err = New(1024) - require.Error(t, err) - - _, err = New(39666) - require.Error(t, err) -} - -func TestBufferID(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - require.NotEqual(t, 0, buf.ID()) -} - -func TestBufferSize(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - require.NotEqual(t, 0, buf.Size()) -} - -func TestBufferClosed(t *testing.T) { - buf, err := New(0) - require.NoError(t, err) - - p := make([]byte, 1024) - for i := range p { - p[i] = 'a' - } - - _, err = buf.ReadFrom(bytes.NewBuffer(p)) - require.EqualError(t, io.EOF, err.Error()) - - go func() { - buf.Close() // nolint: errcheck - }() -} - -func BenchmarkBufferConsumerProducerRead(b *testing.B) { - buf, _ := New(0) - benchmarkRead(b, buf) -} - -func testFillBuffer(t *testing.T, bufsize, ringsize int64) *Type { - buf, err := New(ringsize) - - require.NoError(t, err) - - fillBuffer(t, buf, bufsize) - - require.Equal(t, int(bufsize), buf.Len()) - - return buf -} - -func fillBuffer(t *testing.T, buf *Type, bufsize int64) { - p := make([]byte, bufsize) - for i := range p { - p[i] = 'a' - } - - n, err := buf.ReadFrom(bytes.NewBuffer(p)) - - require.Equal(t, bufsize, n) - require.Equal(t, err, io.EOF) -} - -func peekBuffer(t *testing.T, buf *Type, n int) { - pkbuf, err := buf.ReadPeek(n) - - require.NoError(t, err) - require.Equal(t, n, len(pkbuf)) - - for _, b := range pkbuf { - require.Equal(t, byte('a'), b) - } -} - -func testPeekCommit(t *testing.T, buf *Type) { - n := 20000 - - go func(n int64) { - fillBuffer(t, buf, n) - }(int64(n)) - - i := 0 - - for n > 0 { - pkbuf, _ := buf.ReadPeek(1024) - l, err := buf.ReadCommit(len(pkbuf)) - - require.NoError(t, err) - - n -= l - i += l - } -} - -func testWriteTo(t *testing.T, buf *Type) { - n := int64(20000) - - go func(n int64) { - fillBuffer(t, buf, n) - time.Sleep(time.Millisecond * 100) - buf.Close() // nolint: errcheck - }(n) - - m, err := buf.WriteTo(bytes.NewBuffer(make([]byte, n))) - - require.Equal(t, io.EOF, err) - require.Equal(t, int64(20000), m) -} - -func testRead(t *testing.T, buf *Type) { - n := int64(20000) - - go func(n int64) { - fillBuffer(t, buf, n) - }(n) - - p := make([]byte, n) - i := 0 - - for n > 0 { - l, err := buf.Read(p[i:]) - - require.NoError(t, err) - - n -= int64(l) - i += l - } -} - -func testCommit(t *testing.T, buf *Type) { - n, err := buf.ReadCommit(256) - - require.NoError(t, err) - require.Equal(t, 256, n) - - _, err = buf.ReadCommit(2048) - - require.Equal(t, ErrInsufficientData, err) -} - -func testReadBytes(t *testing.T, buf *Type) { - p := make([]byte, 256) - n, err := buf.Read(p) - - require.NoError(t, err) - require.Equal(t, 256, n) - - p2 := make([]byte, 4096) - n, err = buf.Read(p2) - - require.NoError(t, err) - require.Equal(t, 2048-256, n) -} - -func benchmarkRead(b *testing.B, buf *Type) { - n := int64(b.N) - - go func(n int64) { - p := make([]byte, n) - buf.ReadFrom(bytes.NewBuffer(p)) // nolint: errcheck - }(n) - - p := make([]byte, n) - i := 0 - - for n > 0 { - l, _ := buf.Read(p[i:]) - - n -= int64(l) - i += l - } -} diff --git a/connection/net.go b/connection/net.go index 5f6c21b..b772098 100644 --- a/connection/net.go +++ b/connection/net.go @@ -9,7 +9,7 @@ import ( "errors" "sync" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" "github.com/troian/surgemq/configuration" "github.com/troian/surgemq/packet" "github.com/troian/surgemq/systree" @@ -41,7 +41,7 @@ type netConfig struct { on onProcess // Conn is network connection - conn io.Closer + conn net.Conn // PacketsMetric interface to metric packets packetsMetric systree.PacketsMetric @@ -56,15 +56,34 @@ type netConfig struct { protoVersion packet.ProtocolVersion } +type keepAlive struct { + period time.Duration + conn net.Conn + timer *time.Timer +} + +func (k *keepAlive) Read(b []byte) (int, error) { + if k.period > 0 { + if !k.timer.Stop() { + <-k.timer.C + } + k.timer.Reset(k.period) + } + return k.conn.Read(b) +} + // netConn implementation of the connection type netConn struct { // Incoming data buffer. Bytes are read from the connection and put in here - in *buffer.Type - - // Outgoing data buffer. Bytes written here are in turn written out to the connection - out *buffer.Type + in *goring.Buffer config *netConfig + sendTicker *time.Timer + currLock sync.Mutex + currOutBuffer net.Buffers + outBuffers chan net.Buffers + keepAlive keepAlive + // Wait for the various goroutines to finish starting and stopping wg struct { routines struct { @@ -85,21 +104,10 @@ type netConn struct { // Quit signal for determining when this service should end. If channel is closed, then exit expireIn *time.Duration done chan struct{} - wmu sync.Mutex onStop types.Once will bool } -type netReader interface { - io.Reader - SetReadDeadline(t time.Time) error -} - -type timeoutReader struct { - d time.Duration - conn netReader -} - // newNet connection func newNet(config *netConfig) (f *netConn, err error) { defer func() { @@ -109,11 +117,15 @@ func newNet(config *netConfig) (f *netConn, err error) { }() f = &netConn{ - config: config, - done: make(chan struct{}), - will: true, + config: config, + done: make(chan struct{}), + will: true, + outBuffers: make(chan net.Buffers, 5), + sendTicker: time.NewTimer(5 * time.Millisecond), } + f.sendTicker.Stop() + f.log.prod = configuration.GetProdLogger().Named("session.conn." + config.id) f.log.dev = configuration.GetDevLogger().Named("session.conn." + config.id) @@ -121,15 +133,16 @@ func newNet(config *netConfig) (f *netConn, err error) { f.wg.conn.stopped.Add(1) // Create the incoming ring buffer - f.in, err = buffer.New(buffer.DefaultBufferSize) + f.in, err = goring.New(goring.DefaultBufferSize) if err != nil { return nil, err } - // Create the outgoing ring buffer - f.out, err = buffer.New(buffer.DefaultBufferSize) - if err != nil { - return nil, err + f.keepAlive.conn = f.config.conn + + if f.config.keepAlive > 0 { + f.keepAlive.period = time.Second * time.Duration(f.config.keepAlive) + f.keepAlive.period = f.keepAlive.period + (f.keepAlive.period / 2) } return f, nil @@ -139,7 +152,11 @@ func newNet(config *netConfig) (f *netConn, err error) { func (s *netConn) start() { defer s.wg.conn.started.Done() - s.wg.routines.stopped.Add(3) + if s.keepAlive.period > 0 { + s.wg.routines.stopped.Add(4) + } else { + s.wg.routines.stopped.Add(3) + } // these routines must start in specified order // and next proceed next one only when previous finished @@ -151,6 +168,12 @@ func (s *netConn) start() { go s.processIncoming() s.wg.routines.started.Wait() + if s.keepAlive.period > 0 { + s.wg.routines.started.Add(1) + go s.readTimeOutWorker() + s.wg.routines.started.Wait() + } + s.wg.routines.started.Add(1) go s.receiver() s.wg.routines.started.Wait() @@ -183,9 +206,7 @@ func (s *netConn) stop(reason *packet.ReasonCode) { s.log.prod.Error("close input buffer error", zap.String("ClientID", s.config.id), zap.Error(err)) } - if err := s.out.Close(); err != nil { - s.log.prod.Error("close output buffer error", zap.String("ClientID", s.config.id), zap.Error(err)) - } + s.sendTicker.Stop() // Wait for all the connection goroutines are finished s.wg.routines.stopped.Wait() @@ -196,14 +217,6 @@ func (s *netConn) stop(reason *packet.ReasonCode) { }) } -// Read -func (r timeoutReader) Read(b []byte) (int, error) { - if err := r.conn.SetReadDeadline(time.Now().Add(r.d)); err != nil { - return 0, err - } - return r.conn.Read(b) -} - // isDone func (s *netConn) isDone() bool { select { @@ -314,45 +327,61 @@ func (s *netConn) processIncoming() { } } +func (s *netConn) readTimeOutWorker() { + defer s.onRoutineReturn() + + s.keepAlive.timer = time.NewTimer(s.keepAlive.period) + s.wg.routines.started.Done() + + select { + case <-s.keepAlive.timer.C: + s.log.prod.Error("Keep alive timed out") + return + case <-s.done: + s.keepAlive.timer.Stop() + return + } +} + // receiver reads data from the network, and writes the data into the incoming buffer func (s *netConn) receiver() { defer s.onRoutineReturn() s.wg.routines.started.Done() - switch conn := s.config.conn.(type) { - case net.Conn: - keepAlive := time.Second * time.Duration(s.config.keepAlive) - r := timeoutReader{ - d: keepAlive + (keepAlive / 2), - conn: conn, - } - - for { - if _, err := s.in.ReadFrom(r); err != nil { - return - } - } - default: - s.log.prod.Error("Invalid connection type", zap.String("ClientID", s.config.id)) - } + s.in.ReadFrom(&s.keepAlive) // nolint: errcheck } // sender writes data from the outgoing buffer to the network func (s *netConn) sender() { defer s.onRoutineReturn() - s.wg.routines.started.Done() - switch conn := s.config.conn.(type) { - case net.Conn: - for { - if _, err := s.out.WriteTo(conn); err != nil { + for { + bufs := net.Buffers{} + select { + case <-s.sendTicker.C: + s.currLock.Lock() + s.outBuffers <- s.currOutBuffer + s.currOutBuffer = net.Buffers{} + s.currLock.Unlock() + case buf, ok := <-s.outBuffers: + s.sendTicker.Stop() + if !ok { + return + } + bufs = buf + case <-s.done: + s.sendTicker.Stop() + close(s.outBuffers) + return + } + + if len(bufs) > 0 { + if _, err := bufs.WriteTo(s.config.conn); err != nil { return } } - default: - s.log.prod.Error("Invalid connection type", zap.String("ClientID", s.config.id)) } } @@ -364,7 +393,7 @@ func (s *netConn) peekMessageSize() (packet.Type, int, error) { cnt := 2 if s.in == nil { - err = buffer.ErrNotReady + err = goring.ErrNotReady return 0, 0, err } @@ -420,7 +449,7 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { var msg packet.Provider if s.in == nil { - err = buffer.ErrNotReady + err = goring.ErrNotReady return nil, 0, err } @@ -445,7 +474,7 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { s.log.prod.Error("Incoming and outgoing length does not match", zap.Int("in", total), zap.Int("out", dTotal)) - return nil, 0, buffer.ErrNotReady + return nil, 0, goring.ErrNotReady } return msg, n, err @@ -454,26 +483,37 @@ func (s *netConn) readMessage(total int) (packet.Provider, int, error) { // WriteMessage writes a message to the outgoing buffer func (s *netConn) WriteMessage(msg packet.Provider, lastMessage bool) (int, error) { if s.isDone() { - return 0, buffer.ErrNotReady + return 0, goring.ErrNotReady } if lastMessage { close(s.done) } - defer s.wmu.Unlock() - s.wmu.Lock() + var total int + + expectedSize, err := msg.Size() + if err != nil { + return 0, err + } - if s.out == nil { - return 0, buffer.ErrNotReady + buf := make([]byte, expectedSize) + total, err = msg.Encode(buf) + if err != nil { + return 0, err } - var total int - var err error + s.currLock.Lock() + s.currOutBuffer = append(s.currOutBuffer, buf) + if len(s.currOutBuffer) == 1 { + s.sendTicker.Reset(1 * time.Millisecond) + } - if total, err = packet.WriteToBuffer(msg, s.out); err == nil { - s.config.packetsMetric.Sent(msg.Type()) + if len(s.currOutBuffer) == 10 { + s.outBuffers <- s.currOutBuffer + s.currOutBuffer = net.Buffers{} } + s.currLock.Unlock() return total, err } diff --git a/connection/netCallbacks.go b/connection/netCallbacks.go index 199c93b..5d4774d 100644 --- a/connection/netCallbacks.go +++ b/connection/netCallbacks.go @@ -29,9 +29,6 @@ func (s *Type) getState() *persistenceTypes.SessionMessages { outMessages := [][]byte{} unAckMessages := [][]byte{} - //messages := s.publisher.messages.GetAll() - - //for _, v := range messages { var next *list.Element for elem := s.publisher.messages.Front(); elem != nil; elem = next { next = elem.Next() @@ -258,7 +255,7 @@ func (s *Type) onSubscribe(msg *packet.Subscribe) error { t := kv.Key.(string) ops := kv.Value.(packet.SubscriptionOptions) - reason := packet.CodeSuccess + reason := packet.CodeSuccess // nolint: ineffassign //authorized := true // TODO: check permissions here diff --git a/examples/surgemq/surgemq.go b/examples/surgemq/surgemq.go index 12a4bcb..6632f2e 100644 --- a/examples/surgemq/surgemq.go +++ b/examples/surgemq/surgemq.go @@ -29,6 +29,7 @@ import ( "go.uber.org/zap" _ "net/http/pprof" + "runtime" _ "runtime/debug" ) @@ -44,7 +45,7 @@ func main() { var err error logger.Info("Starting application") - + logger.Info("Allocated cores", zap.Int("GOMAXPROCS", runtime.GOMAXPROCS(0))) viper.SetConfigName("config") viper.AddConfigPath("conf") viper.SetConfigType("json") @@ -119,17 +120,19 @@ func main() { logger.Error("Couldn't start listener", zap.Error(err)) } - configWs := transport.NewConfigWS( - &transport.Config{ - Port: 8080, - AuthManager: authMng, - }) - - if err = srv.ListenAndServe(configWs); err != nil { - logger.Error("Couldn't start listener", zap.Error(err)) - } + //configWs := transport.NewConfigWS( + // &transport.Config{ + // Port: 8080, + // AuthManager: authMng, + // }) + // + //if err = srv.ListenAndServe(configWs); err != nil { + // logger.Error("Couldn't start listener", zap.Error(err)) + //} go func() { + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) logger.Info(http.ListenAndServe("localhost:6061", nil).Error()) }() diff --git a/packet/connack_test.go b/packet/connack_test.go index dc22e06..50d8a60 100644 --- a/packet/connack_test.go +++ b/packet/connack_test.go @@ -18,7 +18,7 @@ import ( "testing" "github.com/stretchr/testify/require" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" ) func TestConnAckMessageFields(t *testing.T) { @@ -184,7 +184,7 @@ func TestConnAckEncodeEnsureSize(t *testing.T) { } func TestConnAckCodeWrite(t *testing.T) { - buf, err := buffer.New(16384) + buf, err := goring.New(16384) require.NoError(t, err) buf.ExternalBuf = make([]byte, 1) diff --git a/packet/routines.go b/packet/routines.go index 912f87a..673c76b 100644 --- a/packet/routines.go +++ b/packet/routines.go @@ -3,11 +3,11 @@ package packet import ( "encoding/binary" - "github.com/troian/surgemq/buffer" + "github.com/troian/goring" ) // WriteToBuffer encode and send message into ring buffer -func WriteToBuffer(msg Provider, to *buffer.Type) (int, error) { +func WriteToBuffer(msg Provider, to *goring.Buffer) (int, error) { expectedSize, err := msg.Size() if err != nil { return 0, err @@ -38,14 +38,10 @@ func ReadLPBytes(buf []byte) ([]byte, int, error) { n = int(binary.BigEndian.Uint16(buf)) total += 2 - if len(buf[total:]) < n { - return nil, total, ErrInsufficientDataSize - } - // Check for malformed length-prefixed field // if remaining space is less than length-prefixed size the packet seems to be broken if len(buf[total:]) < n { - return nil, total, ErrInvalidLength + return nil, total, ErrInsufficientDataSize } total += n diff --git a/routines/misc.go b/routines/misc.go index 31c8cd1..a443054 100644 --- a/routines/misc.go +++ b/routines/misc.go @@ -116,9 +116,3 @@ func WriteMessageBuffer(c io.Closer, b []byte) error { _, err := conn.Write(b) return err } - -// Copied from http://golang.org/src/pkg/net/timeout_test.go -//func isTimeout(err error) bool { -// e, ok := err.(net.Error) -// return ok && e.Timeout() -//} diff --git a/transport/base.go b/transport/base.go index 04133d8..605f773 100644 --- a/transport/base.go +++ b/transport/base.go @@ -115,6 +115,8 @@ func (c *baseConfig) handleConnection(conn conn) { } } } else { + // Disable read deadline. Will set it later if keep-alive interval is bigger than 0 + conn.SetReadDeadline(time.Time{}) // nolint: errcheck switch r := req.(type) { case *packet.Connect: m, _ := packet.NewMessage(req.Version(), packet.CONNACK)