Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ReadPacket and WritePacket payload length is a multiple of #403

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/hack/version.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package hack

const (
Version = "2017-10-24 13:57:15 +0800 @e992c6f"
Compile = "2017-11-10 20:44:38 +0800 by go version go1.9 darwin/amd64"
Version = "2017-11-10 20:46:09 +0800 @d33d0d5"
Compile = "2017-12-12 11:49:09 +0800 by go version go1.9.2 darwin/amd64"
)
107 changes: 57 additions & 50 deletions mysql/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,79 +44,86 @@ func NewPacketIO(conn net.Conn) *PacketIO {
}

func (p *PacketIO) ReadPacket() ([]byte, error) {
header := []byte{0, 0, 0, 0}

if _, err := io.ReadFull(p.rb, header); err != nil {
return nil, ErrBadConn
}
var prevData []byte
for {
// read packet header
header := []byte{0, 0, 0, 0}
if _, err := io.ReadFull(p.rb, header); err != nil {
return nil, ErrBadConn
}

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
if length < 1 {
return nil, fmt.Errorf("invalid payload length %d", length)
}
// packet length [24 bit]
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

sequence := uint8(header[3])
// check packet sync [8 bit]
sequence := uint8(header[3])
if sequence != p.Sequence {
return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence)
}
p.Sequence++

if sequence != p.Sequence {
return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence)
}
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)−1 bytes long
if length == 0 {
// there was no previous packet
if prevData == nil {
return nil, fmt.Errorf("invalid payload length %d", length)
}
return prevData, nil
}

p.Sequence++
// read packet body [length bytes]
data := make([]byte, length)
if _, err := io.ReadFull(p.rb, data); err != nil {
return nil, ErrBadConn
}

data := make([]byte, length)
if _, err := io.ReadFull(p.rb, data); err != nil {
return nil, ErrBadConn
} else {
// return data if this was the last packet
if length < MaxPayloadLen {
return data, nil
}
// zero allocations for non-split packets
if prevData == nil {
return data, nil
}

var buf []byte
buf, err = p.ReadPacket()
if err != nil {
return nil, ErrBadConn
} else {
return append(data, buf...), nil
return append(prevData, data...), nil
}
prevData = append(prevData, data...)
}
}

//data already have header
func (p *PacketIO) WritePacket(data []byte) error {
length := len(data) - 4

for length >= MaxPayloadLen {

data[0] = 0xff
data[1] = 0xff
data[2] = 0xff

for {
var size int
if length >= MaxPayloadLen {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = MaxPayloadLen
} else {
data[0] = byte(length)
data[1] = byte(length >> 8)
data[2] = byte(length >> 16)
size = length
}
data[3] = p.Sequence

if n, err := p.wb.Write(data[:4+MaxPayloadLen]); err != nil {
if n, err := p.wb.Write(data[:4+size]); err != nil {
return ErrBadConn
} else if n != (4 + MaxPayloadLen) {
} else if n != (4 + size) {
return ErrBadConn
} else {
p.Sequence++
length -= MaxPayloadLen
data = data[MaxPayloadLen:]
if size != MaxPayloadLen {
return nil
}
length -= size
data = data[size:]
continue
}
}

data[0] = byte(length)
data[1] = byte(length >> 8)
data[2] = byte(length >> 16)
data[3] = p.Sequence

if n, err := p.wb.Write(data); err != nil {
return ErrBadConn
} else if n != len(data) {
return ErrBadConn
} else {
p.Sequence++
return nil
}
}

func (p *PacketIO) WritePacketBatch(total, data []byte, direct bool) ([]byte, error) {
Expand Down