diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 925dfbfa8e4..d61549c92ef 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -787,15 +787,15 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro // assuming CapabilityClientProtocol41 length += 4 // status_flags + warnings + hasSessionTrack := c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack + hasGtidData := hasSessionTrack && packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged + var gtidData []byte - if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack { + + if hasSessionTrack { length += lenEncStringSize(packetOk.info) // info - if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged { - gtidData = getLenEncString([]byte(packetOk.sessionStateData)) - gtidData = append([]byte{0x00}, gtidData...) - gtidData = getLenEncString(gtidData) - gtidData = append([]byte{0x03}, gtidData...) - gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...) + if hasGtidData { + gtidData = encGtidData(packetOk.sessionStateData) length += len(gtidData) } } else { @@ -809,10 +809,10 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro data.writeLenEncInt(packetOk.lastInsertID) data.writeUint16(packetOk.statusFlags) data.writeUint16(packetOk.warnings) - if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack { + if hasSessionTrack { data.writeLenEncString(packetOk.info) - if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged { - data.writeEOFString(string(gtidData)) + if hasGtidData { + data.writeEOFBytes(gtidData) } } else { data.writeEOFString(packetOk.info) @@ -820,39 +820,6 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro return c.writeEphemeralPacket() } -func getLenEncString(value []byte) []byte { - data := getLenEncInt(uint64(len(value))) - return append(data, value...) -} - -func getLenEncInt(i uint64) []byte { - var data []byte - switch { - case i < 251: - data = append(data, byte(i)) - case i < 1<<16: - data = append(data, 0xfc) - data = append(data, byte(i)) - data = append(data, byte(i>>8)) - case i < 1<<24: - data = append(data, 0xfd) - data = append(data, byte(i)) - data = append(data, byte(i>>8)) - data = append(data, byte(i>>16)) - default: - data = append(data, 0xfe) - data = append(data, byte(i)) - data = append(data, byte(i>>8)) - data = append(data, byte(i>>16)) - data = append(data, byte(i>>24)) - data = append(data, byte(i>>32)) - data = append(data, byte(i>>40)) - data = append(data, byte(i>>48)) - data = append(data, byte(i>>56)) - } - return data -} - func (c *Conn) WriteErrorAndLog(format string, args ...interface{}) bool { return c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, format, args...) } @@ -1290,7 +1257,6 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) { c.PrepareData[c.StatementID] = prepare fld, err := handler.ComPrepare(c, queries[0], bindVars) - if err != nil { return c.writeErrorPacketFromErrorAndLog(err) } diff --git a/go/mysql/encoding.go b/go/mysql/encoding.go index 6b33ffabfc2..b9d015e8948 100644 --- a/go/mysql/encoding.go +++ b/go/mysql/encoding.go @@ -332,6 +332,53 @@ func readLenEncStringAsBytesCopy(data []byte, pos int) ([]byte, int, bool) { return result, pos + s, true } +// > encGtidData("xxx") +// +// [07 03 05 00 03 78 78 78] +// | | | | | |------| +// | | | | | ^-------- "xxx" +// | | | | ^------------ length of rest of bytes, 3 +// | | | ^--------------- fixed 0x00 +// | | ^------------------ length of rest of bytes, 5 +// | ^--------------------- fixed 0x03 (SESSION_TRACK_GTIDS) +// ^------------------------ length of rest of bytes, 7 +// +// This is ultimately lenencoded strings of length encoded strings, or: +// > lenenc(0x03 + lenenc(0x00 + lenenc(data))) +func encGtidData(data string) []byte { + const SessionTrackGtids = 0x03 + + // calculate total size up front to do 1 allocation + // encoded layout is: + // lenenc(0x03 + lenenc(0x00 + lenenc(data))) + dataSize := uint64(len(data)) + dataLenEncSize := uint64(lenEncIntSize(dataSize)) + + wrapSize := uint64(dataSize + dataLenEncSize + 1) + wrapLenEncSize := uint64(lenEncIntSize(wrapSize)) + + totalSize := uint64(wrapSize + wrapLenEncSize + 1) + totalLenEncSize := uint64(lenEncIntSize(totalSize)) + + gtidData := make([]byte, int(totalSize+totalLenEncSize)) + + pos := 0 + pos = writeLenEncInt(gtidData, pos, totalSize) + + gtidData[pos] = SessionTrackGtids + pos++ + + pos = writeLenEncInt(gtidData, pos, wrapSize) + + gtidData[pos] = 0x00 + pos++ + + pos = writeLenEncInt(gtidData, pos, dataSize) + writeEOFString(gtidData, pos, data) + + return gtidData +} + type coder struct { data []byte pos int @@ -397,3 +444,7 @@ func (d *coder) writeLenEncString(value string) { func (d *coder) writeEOFString(value string) { d.pos += copy(d.data[d.pos:], value) } + +func (d *coder) writeEOFBytes(value []byte) { + d.pos += copy(d.data[d.pos:], value) +} diff --git a/go/mysql/encoding_test.go b/go/mysql/encoding_test.go index 41f6c416993..c2b7013e2cb 100644 --- a/go/mysql/encoding_test.go +++ b/go/mysql/encoding_test.go @@ -18,6 +18,7 @@ package mysql import ( "bytes" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -72,7 +73,6 @@ func TestEncLenInt(t *testing.T) { // Check failed decoding. _, _, ok = readLenEncInt(test.encoded[:len(test.encoded)-1], 0) assert.False(t, ok, "readLenEncInt returned ok=true for shorter value %x", test.value) - } } @@ -355,6 +355,27 @@ func TestWriteZeroes(t *testing.T) { }) } +func TestEncGtidData(t *testing.T) { + tests := []struct { + data string + header []byte + }{ + {"", []byte{0x04, 0x03, 0x02, 0x00, 0x00}}, + {"xxx", []byte{0x07, 0x03, 0x05, 0x00, 0x03}}, + {strings.Repeat("x", 256), []byte{ + /* 264 */ 0xfc, 0x08, 0x01, + /* constant */ 0x03, + /* 260 */ 0xfc, 0x04, 0x01, + /* constant */ 0x00, + /* 256 */ 0xfc, 0x00, 0x01, + }}, + } + for _, test := range tests { + got := encGtidData(test.data) + assert.Equal(t, append(test.header, test.data...), got) + } +} + func BenchmarkEncWriteInt(b *testing.B) { buf := make([]byte, 16) @@ -451,3 +472,10 @@ func BenchmarkEncReadInt(b *testing.B) { } }) } + +func BenchmarkEncGtidData(b *testing.B) { + b.ReportAllocs() + for range b.N { + _ = encGtidData("xxx") + } +}