-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.go
134 lines (119 loc) · 3.08 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package chanserv
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
"github.com/pierrec/lz4"
)
// Sentinel errors reported when a frame header announces an invalid size.
var (
	// ErrWrongSize reports a frame whose declared payload size exceeds
	// FrameSizeLimit.
	ErrWrongSize = errors.New("wrong frame size")
	// ErrWrongUncompressedSize reports a compressed frame whose declared
	// uncompressed size is invalid.
	ErrWrongUncompressedSize = errors.New("wrong uncompressed frame size")
)

// FrameSizeLimit specifies the maximum size of payload in a frame (100 MiB);
// this limit may be increased or lifted in future.
const FrameSizeLimit = 100 << 20

// CompressionHeader is the magic prefix that marks a frame payload as an
// LZ4-compressed block.
var CompressionHeader = []byte("lz4!")
// writeFrame writes frame to wr as an uncompressed frame: an 8-byte
// little-endian payload length followed by the payload bytes. The length
// prefix and the payload are written with separate calls. The first write
// error encountered is returned.
func writeFrame(wr io.Writer, frame []byte) (err error) {
	var lenPrefix [8]byte
	binary.LittleEndian.PutUint64(lenPrefix[:], uint64(len(frame)))
	if _, err = wr.Write(lenPrefix[:]); err == nil {
		_, err = io.Copy(wr, bytes.NewReader(frame))
	}
	return err
}
var hashtable [1 << 16]int
func writeCompressedFrame(wr io.Writer, frame []byte) (err error) {
comp := make([]byte, lz4.CompressBlockBound(len(frame)))
size, err := lz4.CompressBlock(frame, comp, hashtable[:])
if err != nil {
return err
}
if size >= len(frame) {
// discard compressed results
return writeFrame(wr, frame)
}
comp = comp[:size]
frameSize := size + len(CompressionHeader) + 8
buf := make([]byte, 8+len(CompressionHeader)+8)
binary.LittleEndian.PutUint64(buf, uint64(frameSize))
copy(buf[8:], CompressionHeader)
binary.LittleEndian.PutUint64(buf[12:], uint64(len(frame)))
if _, err = wr.Write(buf); err != nil {
return
}
_, err = io.Copy(wr, bytes.NewReader(comp))
return
}
// readFrame reads one length-prefixed frame from r.
//
// A frame is an 8-byte little-endian payload length followed by the payload.
// The behavior is as follows:
//   - A payload larger than FrameSizeLimit is rejected with ErrWrongSize
//     before anything is allocated for it.
//   - An error (including io.EOF) while reading the length prefix is returned
//     as-is.
//   - A stream that ends after the prefix but before the payload is complete
//     yields io.ErrUnexpectedEOF.
//
// When expectCompression is true, and the payload is longer than
// len(CompressionHeader)+8 bytes and starts with CompressionHeader, it is
// decoded as CompressionHeader, an 8-byte little-endian uncompressed size, and
// an LZ4 block. The block is then decompressed. ErrWrongUncompressedSize is
// returned if the declared size exceeds 2*FrameSizeLimit or does not match
// the decoded length. Any other payload is returned unchanged.
//
// NOTE: an uncompressed payload that happens to begin with CompressionHeader
// is indistinguishable from a compressed one; the wire format has no escape.
func readFrame(r io.Reader, expectCompression bool) ([]byte, error) {
	var prefix [8]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		return nil, err
	}
	frameSize := binary.LittleEndian.Uint64(prefix[:])
	// check frame size for bounds before allocating
	if frameSize > FrameSizeLimit {
		return nil, ErrWrongSize
	}

	// Read straight into an exact-size slice. A bytes.Buffer pre-sized to
	// frameSize would still grow (~2x + copy) on its final ReadFrom iteration.
	data := make([]byte, frameSize)
	if _, err := io.ReadFull(r, data); err != nil {
		if err == io.EOF {
			// The prefix promised a payload, so a bare EOF here is a truncation.
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	if !expectCompression {
		return data, nil
	}

	magicLen := len(CompressionHeader)
	if len(data) <= magicLen+8 || !bytes.Equal(data[:magicLen], CompressionHeader) {
		// Too short to carry a compressed block, or not marked as compressed.
		return data, nil
	}
	uncompressedSize := binary.LittleEndian.Uint64(data[magicLen:])
	// check the size for bounds
	if uncompressedSize > FrameSizeLimit*2 {
		return nil, ErrWrongUncompressedSize
	}
	uncompressed := make([]byte, uncompressedSize)
	size, err := lz4.UncompressBlock(data[magicLen+8:], uncompressed)
	if err != nil {
		return nil, err
	}
	if uint64(size) != uncompressedSize {
		// The block decoded to a different length than the header announced.
		return nil, ErrWrongUncompressedSize
	}
	return uncompressed, nil
}
// needCompression reports whether data is strictly longer than the fixed
// overhead of a compressed frame payload (CompressionHeader plus an 8-byte
// uncompressed-size field).
func needCompression(data []byte) bool {
	overhead := len(CompressionHeader) + 8
	return overhead < len(data)
}
// timerPool recycles *time.Timer values for acceptTimeout. Every timer it
// creates is already stopped, so a caller must Reset it before waiting on C.
var timerPool = sync.Pool{
	New: func() interface{} {
		idle := time.NewTimer(time.Minute)
		idle.Stop()
		return idle
	},
}
// acceptTimeout waits up to d for l to accept a connection.
//
// If Accept returns first, its connection and error are returned unchanged.
// If d elapses first, l is closed and io.ErrNoProgress is returned. Per the
// net.Listener contract, Close unblocks the pending Accept. A connection that
// Accept still produces in that race window is closed rather than leaked.
// Because the listener is closed on timeout, it cannot be reused afterwards.
func acceptTimeout(l net.Listener, d time.Duration) (conn net.Conn, err error) {
	timeout := timerPool.Get().(*time.Timer)
	timeout.Reset(d)
	defer func() {
		// If the timer fired but its tick was never received (Accept won the
		// race), drain it. Otherwise the pooled timer could carry a stale tick
		// that makes the next caller time out immediately.
		if !timeout.Stop() {
			select {
			case <-timeout.C:
			default:
			}
		}
		timerPool.Put(timeout)
	}()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Results travel over a channel instead of being written into the named
	// return values, so the Accept goroutine never races with the return.
	done := make(chan acceptResult)
	abandoned := make(chan struct{})
	go func() {
		c, e := l.Accept()
		select {
		case done <- acceptResult{conn: c, err: e}:
		case <-abandoned:
			// Nobody will receive this connection any more.
			if c != nil {
				c.Close()
			}
		}
	}()

	select {
	case res := <-done:
		return res.conn, res.err
	case <-timeout.C:
		close(abandoned)
		// Best effort: the close error is irrelevant to the timeout outcome.
		l.Close()
		return nil, io.ErrNoProgress
	}
}