diff --git a/core/hack/version.go b/core/hack/version.go index 94449fc0..03a48ac1 100644 --- a/core/hack/version.go +++ b/core/hack/version.go @@ -1,6 +1,6 @@ package hack const ( - Version = "2017-10-24 13:57:15 +0800 @e992c6f" - Compile = "2017-11-10 20:44:38 +0800 by go version go1.9 darwin/amd64" + Version = "2017-11-10 20:46:09 +0800 @d33d0d5" + Compile = "2017-12-12 11:49:09 +0800 by go version go1.9.2 darwin/amd64" ) diff --git a/mysql/packetio.go b/mysql/packetio.go index b08808df..174ba41d 100644 --- a/mysql/packetio.go +++ b/mysql/packetio.go @@ -44,40 +44,50 @@ func NewPacketIO(conn net.Conn) *PacketIO { } func (p *PacketIO) ReadPacket() ([]byte, error) { - header := []byte{0, 0, 0, 0} - - if _, err := io.ReadFull(p.rb, header); err != nil { - return nil, ErrBadConn - } + var prevData []byte + for { + // read packet header + header := []byte{0, 0, 0, 0} + if _, err := io.ReadFull(p.rb, header); err != nil { + return nil, ErrBadConn + } - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - if length < 1 { - return nil, fmt.Errorf("invalid payload length %d", length) - } + // packet length [24 bit] + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - sequence := uint8(header[3]) + // check packet sync [8 bit] + sequence := uint8(header[3]) + if sequence != p.Sequence { + return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence) + } + p.Sequence++ - if sequence != p.Sequence { - return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence) - } + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if length == 0 { + // there was no previous packet + if prevData == nil { + return nil, fmt.Errorf("invalid payload length %d", length) + } + return prevData, nil + } - p.Sequence++ + // read packet body [length bytes] + data := make([]byte, length) + if _, err := io.ReadFull(p.rb, data); err != nil { + return nil, ErrBadConn + } - data := make([]byte, length) - if _, err := io.ReadFull(p.rb, data); err != nil { - return nil, ErrBadConn - } else { + // return data if this was the last packet if length < MaxPayloadLen { - return data, nil - } + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - var buf []byte - buf, err = p.ReadPacket() - if err != nil { - return nil, ErrBadConn - } else { - return append(data, buf...), nil + return append(prevData, data...), nil } + prevData = append(prevData, data...) } } @@ -85,38 +95,35 @@ func (p *PacketIO) ReadPacket() ([]byte, error) { func (p *PacketIO) WritePacket(data []byte) error { length := len(data) - 4 - for length >= MaxPayloadLen { - - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - + for { + var size int + if length >= MaxPayloadLen { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = MaxPayloadLen + } else { + data[0] = byte(length) + data[1] = byte(length >> 8) + data[2] = byte(length >> 16) + size = length + } data[3] = p.Sequence - if n, err := p.wb.Write(data[:4+MaxPayloadLen]); err != nil { + if n, err := p.wb.Write(data[:4+size]); err != nil { return ErrBadConn - } else if n != (4 + MaxPayloadLen) { + } else if n != (4 + size) { return ErrBadConn } else { p.Sequence++ - length -= MaxPayloadLen - data = data[MaxPayloadLen:] + if size != MaxPayloadLen { + return nil + } + length -= size + data = data[size:] + continue } } - - data[0] = byte(length) - data[1] = byte(length >> 8) - data[2] = byte(length >> 16) - data[3] = p.Sequence - - if n, err := p.wb.Write(data); err != nil { - return ErrBadConn - } else if n != len(data) { - return ErrBadConn - } else { - p.Sequence++ - return nil - } } func (p *PacketIO) WritePacketBatch(total, data []byte, direct bool) ([]byte, error) {