Skip to content

Commit

Permalink
go/mysql: improve GTID encoding for OK packet (#16361)
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Robenolt <[email protected]>
  • Loading branch information
mattrobenolt authored Jul 10, 2024
1 parent 16b05c1 commit bc32d84
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 45 deletions.
54 changes: 10 additions & 44 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,15 +787,15 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro
// assuming CapabilityClientProtocol41
length += 4 // status_flags + warnings

hasSessionTrack := c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack
hasGtidData := hasSessionTrack && packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged

var gtidData []byte
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {

if hasSessionTrack {
length += lenEncStringSize(packetOk.info) // info
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
gtidData = getLenEncString([]byte(packetOk.sessionStateData))
gtidData = append([]byte{0x00}, gtidData...)
gtidData = getLenEncString(gtidData)
gtidData = append([]byte{0x03}, gtidData...)
gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...)
if hasGtidData {
gtidData = encGtidData(packetOk.sessionStateData)
length += len(gtidData)
}
} else {
Expand All @@ -809,50 +809,17 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro
data.writeLenEncInt(packetOk.lastInsertID)
data.writeUint16(packetOk.statusFlags)
data.writeUint16(packetOk.warnings)
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
if hasSessionTrack {
data.writeLenEncString(packetOk.info)
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
data.writeEOFString(string(gtidData))
if hasGtidData {
data.writeEOFBytes(gtidData)
}
} else {
data.writeEOFString(packetOk.info)
}
return c.writeEphemeralPacket()
}

func getLenEncString(value []byte) []byte {
data := getLenEncInt(uint64(len(value)))
return append(data, value...)
}

func getLenEncInt(i uint64) []byte {
var data []byte
switch {
case i < 251:
data = append(data, byte(i))
case i < 1<<16:
data = append(data, 0xfc)
data = append(data, byte(i))
data = append(data, byte(i>>8))
case i < 1<<24:
data = append(data, 0xfd)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
default:
data = append(data, 0xfe)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
data = append(data, byte(i>>24))
data = append(data, byte(i>>32))
data = append(data, byte(i>>40))
data = append(data, byte(i>>48))
data = append(data, byte(i>>56))
}
return data
}

func (c *Conn) WriteErrorAndLog(format string, args ...interface{}) bool {
return c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, format, args...)
}
Expand Down Expand Up @@ -1290,7 +1257,6 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
c.PrepareData[c.StatementID] = prepare

fld, err := handler.ComPrepare(c, queries[0], bindVars)

if err != nil {
return c.writeErrorPacketFromErrorAndLog(err)
}
Expand Down
51 changes: 51 additions & 0 deletions go/mysql/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,53 @@ func readLenEncStringAsBytesCopy(data []byte, pos int) ([]byte, int, bool) {
return result, pos + s, true
}

// > encGtidData("xxx")
//
// [07 03 05 00 03 78 78 78]
// | | | | | |------|
// | | | | | ^-------- "xxx"
// | | | | ^------------ length of rest of bytes, 3
// | | | ^--------------- fixed 0x00
// | | ^------------------ length of rest of bytes, 5
// | ^--------------------- fixed 0x03 (SESSION_TRACK_GTIDS)
// ^------------------------ length of rest of bytes, 7
//
// This is ultimately lenencoded strings of length encoded strings, or:
// > lenenc(0x03 + lenenc(0x00 + lenenc(data)))
func encGtidData(data string) []byte {
const SessionTrackGtids = 0x03

// calculate total size up front to do 1 allocation
// encoded layout is:
// lenenc(0x03 + lenenc(0x00 + lenenc(data)))
dataSize := uint64(len(data))
dataLenEncSize := uint64(lenEncIntSize(dataSize))

wrapSize := uint64(dataSize + dataLenEncSize + 1)
wrapLenEncSize := uint64(lenEncIntSize(wrapSize))

totalSize := uint64(wrapSize + wrapLenEncSize + 1)
totalLenEncSize := uint64(lenEncIntSize(totalSize))

gtidData := make([]byte, int(totalSize+totalLenEncSize))

pos := 0
pos = writeLenEncInt(gtidData, pos, totalSize)

gtidData[pos] = SessionTrackGtids
pos++

pos = writeLenEncInt(gtidData, pos, wrapSize)

gtidData[pos] = 0x00
pos++

pos = writeLenEncInt(gtidData, pos, dataSize)
writeEOFString(gtidData, pos, data)

return gtidData
}

type coder struct {
data []byte
pos int
Expand Down Expand Up @@ -397,3 +444,7 @@ func (d *coder) writeLenEncString(value string) {
func (d *coder) writeEOFString(value string) {
d.pos += copy(d.data[d.pos:], value)
}

func (d *coder) writeEOFBytes(value []byte) {
d.pos += copy(d.data[d.pos:], value)
}
30 changes: 29 additions & 1 deletion go/mysql/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mysql

import (
"bytes"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -72,7 +73,6 @@ func TestEncLenInt(t *testing.T) {
// Check failed decoding.
_, _, ok = readLenEncInt(test.encoded[:len(test.encoded)-1], 0)
assert.False(t, ok, "readLenEncInt returned ok=true for shorter value %x", test.value)

}
}

Expand Down Expand Up @@ -355,6 +355,27 @@ func TestWriteZeroes(t *testing.T) {
})
}

func TestEncGtidData(t *testing.T) {
tests := []struct {
data string
header []byte
}{
{"", []byte{0x04, 0x03, 0x02, 0x00, 0x00}},
{"xxx", []byte{0x07, 0x03, 0x05, 0x00, 0x03}},
{strings.Repeat("x", 256), []byte{
/* 264 */ 0xfc, 0x08, 0x01,
/* constant */ 0x03,
/* 260 */ 0xfc, 0x04, 0x01,
/* constant */ 0x00,
/* 256 */ 0xfc, 0x00, 0x01,
}},
}
for _, test := range tests {
got := encGtidData(test.data)
assert.Equal(t, append(test.header, test.data...), got)
}
}

func BenchmarkEncWriteInt(b *testing.B) {
buf := make([]byte, 16)

Expand Down Expand Up @@ -451,3 +472,10 @@ func BenchmarkEncReadInt(b *testing.B) {
}
})
}

func BenchmarkEncGtidData(b *testing.B) {
b.ReportAllocs()
for range b.N {
_ = encGtidData("xxx")
}
}

0 comments on commit bc32d84

Please sign in to comment.