Skip to content

Commit

Permalink
proto: upgrade vtprotobuf
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg committed Mar 19, 2024
1 parent a41a35a commit 42d98a9
Show file tree
Hide file tree
Showing 34 changed files with 14,232 additions and 15,671 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ require (
github.com/pires/go-proxyproto v0.7.0
github.com/pkg/errors v0.9.1
github.com/planetscale/pargzip v0.0.0-20201116224723-90c7fc03ea8a
github.com/planetscale/vtprotobuf v0.5.0
github.com/planetscale/vtprotobuf v0.6.0
github.com/prometheus/client_golang v1.19.0
github.com/prometheus/common v0.49.0
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ github.com/planetscale/pargzip v0.0.0-20201116224723-90c7fc03ea8a h1:y0OpQ4+5tKx
github.com/planetscale/pargzip v0.0.0-20201116224723-90c7fc03ea8a/go.mod h1:GJFUzQuXIoB2Kjn1ZfDhJr/42D5nWOqRcIQVgCxTuIE=
github.com/planetscale/vtprotobuf v0.5.0 h1:l8PXm6Colok5z6qQLNhAj2Jq5BfoMTIHxLER5a6nDqM=
github.com/planetscale/vtprotobuf v0.5.0/go.mod h1:wm1N3qk9G/4+VM1WhpkLbvY/d8+0PbwYYpP5P5VhTks=
github.com/planetscale/vtprotobuf v0.6.0 h1:nBeETjudeJ5ZgBHUz1fVHvbqUKnYOXNhsIEabROxmNA=
github.com/planetscale/vtprotobuf v0.6.0/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down
30 changes: 30 additions & 0 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,36 @@ func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
return int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16), nil
}

func (c *Conn) readPacketAsProto(b *queryResultBuilder) ([]byte, error) {
r := c.getReader()

length, err := c.readHeaderFrom(r)
if err != nil {
return nil, err
}

if length == 0 {
// This can be caused by the packet after a packet of
// exactly size MaxPacketSize.
return nil, nil
}

// Use the bufPool.
if length < MaxPacketSize {
buf := b.Packet(length)
c.currentEphemeralBuffer = bufPool.Get(length)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
}
return buf, nil
}

// Much slower path, revert to allocating everything from scratch.
// We're going to concatenate a lot of data anyway, can't really
// optimize this code path easily.
panic("TODO: large packets")
}

// readEphemeralPacket attempts to read a packet into buffer from sync.Pool. Do
// not use this method if the contents of the packet needs to be kept
// after the next readEphemeralPacket.
Expand Down
213 changes: 176 additions & 37 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,7 @@ func (c *Conn) writeComSetOption(operation uint16) error {
return nil
}

// readColumnDefinition reads the next Column Definition packet.
// Returns a SQLError.
func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
colDef, err := c.readEphemeralPacket()
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()

func parseColumnDefinition(colDef []byte, field *querypb.Field, index int) error {
// Catalog is ignored, always set to "def"
pos, ok := skipLenEncString(colDef, 0)
if !ok {
Expand Down Expand Up @@ -160,6 +152,7 @@ func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
}

// Convert MySQL type to Vitess type.
var err error
field.Type, err = sqltypes.MySQLToType(t, int64(flags))
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err)
Expand Down Expand Up @@ -191,41 +184,32 @@ func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
return nil
}

// readColumnDefinitionType is a faster version of
// readColumnDefinition that only fills in the Type.
// Returns a SQLError.
func (c *Conn) readColumnDefinitionType(field *querypb.Field, index int) error {
colDef, err := c.readEphemeralPacket()
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()

func parseColumnDefinitionType(colDef []byte, index int) (sqltypes.Type, error) {
// catalog, schema, table, orgTable, name and orgName are
// strings, all skipped.
pos, ok := skipLenEncString(colDef, 0)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v catalog failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v catalog failed", index)
}
pos, ok = skipLenEncString(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v schema failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v schema failed", index)
}
pos, ok = skipLenEncString(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v table failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v table failed", index)
}
pos, ok = skipLenEncString(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v org_table failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v org_table failed", index)
}
pos, ok = skipLenEncString(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v name failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v name failed", index)
}
pos, ok = skipLenEncString(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v org_name failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "skipping col %v org_name failed", index)
}

// Skip length of fixed-length fields.
Expand All @@ -234,41 +218,66 @@ func (c *Conn) readColumnDefinitionType(field *querypb.Field, index int) error {
// characterSet is a uint16.
_, pos, ok = readUint16(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v characterSet failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v characterSet failed", index)
}

// columnLength is a uint32.
_, pos, ok = readUint32(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v columnLength failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v columnLength failed", index)
}

// type is one byte
t, pos, ok := readByte(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v type failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v type failed", index)
}

// flags is 2 bytes
flags, _, ok := readUint16(colDef, pos)
if !ok {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v flags failed", index)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extracting col %v flags failed", index)
}

// Convert MySQL type to Vitess type.
field.Type, err = sqltypes.MySQLToType(t, int64(flags))
sqltype, err := sqltypes.MySQLToType(t, int64(flags))
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err)
return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err)
}

// skip decimals
return sqltype, nil
}

return nil
// readColumnDefinition reads the next Column Definition packet.
// Returns a SQLError.
func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
colDef, err := c.readEphemeralPacket()
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()

return parseColumnDefinition(colDef, field, index)
}

// readColumnDefinitionType is a faster version of
// readColumnDefinition that only fills in the Type.
// Returns a SQLError.
func (c *Conn) readColumnDefinitionType(field *querypb.Field, index int) error {
colDef, err := c.readEphemeralPacket()
if err != nil {
return sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()

field.Type, err = parseColumnDefinitionType(colDef, index)
return err
}

// parseRow parses an individual row.
// Returns a SQLError.
func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte, int) ([]byte, int, bool), result []sqltypes.Value) ([]sqltypes.Value, error) {
func parseRow(data []byte, fields []*querypb.Field, reader func([]byte, int) ([]byte, int, bool), result []sqltypes.Value) ([]sqltypes.Value, error) {
colNumber := len(fields)
if result == nil {
result = make([]sqltypes.Value, 0, colNumber)
Expand Down Expand Up @@ -315,7 +324,18 @@ func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte
// 2. if the server closes the connection when a command is in flight,
// readComQueryResponse will fail, and we'll return CRServerLost(2013).
func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) {
result, more, err := c.ExecuteFetchMulti(query, maxrows, wantfields)
return c.ExecuteFetchOpt(query, ExecuteOptions{MaxRows: maxrows, WantFields: wantfields})
}

type ExecuteOptions struct {
MaxRows int
SizeHint int
WantFields bool
RawPackets bool
}

func (c *Conn) ExecuteFetchOpt(query string, opt ExecuteOptions) (*sqltypes.Result, error) {
result, more, err := c.ExecuteFetchMultiOpt(query, opt)
if more {
// Multiple results are unexpected. Prioritize this "unexpected" error over whatever error we got from the first result.
err = errors.Join(ErrExecuteFetchMultipleResults, err)
Expand Down Expand Up @@ -348,6 +368,10 @@ func (c *Conn) drainMoreResults(more bool, err error) error {
// It returns an additional 'more' flag. If it is set, you must fetch the additional
// results using ReadQueryResult.
func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) {
return c.ExecuteFetchMultiOpt(query, ExecuteOptions{MaxRows: maxrows, WantFields: wantfields})
}

func (c *Conn) ExecuteFetchMultiOpt(query string, opt ExecuteOptions) (result *sqltypes.Result, more bool, err error) {
defer func() {
if err != nil {
if sqlerr, ok := err.(*sqlerror.SQLError); ok {
Expand All @@ -361,11 +385,15 @@ func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (re
return nil, false, err
}

res, more, _, err := c.ReadQueryResult(maxrows, wantfields)
if opt.RawPackets {
result, more, _, err = c.ReadQueryResultAsProto(opt.MaxRows, opt.SizeHint)
} else {
result, more, _, err = c.ReadQueryResult(opt.MaxRows, opt.WantFields)
}
if err != nil {
return nil, false, err
}
return res, more, err
return result, more, err
}

// ExecuteFetchWithWarningCount is for fetching results and a warning count
Expand All @@ -389,6 +417,117 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield
return res, warnings, err
}

func (c *Conn) ReadQueryResultAsProto(maxrows int, sizehint int) (*sqltypes.Result, bool, uint16, error) {
var packetOk PacketOK

// Get the result.
colNumber, err := c.readComQueryResponse(&packetOk)
if err != nil {
return nil, false, 0, err
}
more := packetOk.statusFlags&ServerMoreResultsExists != 0
warnings := packetOk.warnings
if colNumber == 0 {
// OK packet, means no results. Just use the numbers.
return &sqltypes.Result{
RowsAffected: packetOk.affectedRows,
InsertID: packetOk.lastInsertID,
SessionStateChanges: packetOk.sessionStateData,
StatusFlags: packetOk.statusFlags,
Info: packetOk.info,
}, more, warnings, nil
}

var data []byte
builder := newQueryResultBuilder(sizehint)

// Read column headers. One packet per column.
// Build the fields.
for i := 0; i < colNumber; i++ {
_, err = c.readPacketAsProto(&builder)
if err != nil {
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, "", "")
}
}

if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
// EOF is only present here if it's not deprecated.
data, err = c.readPacketAsProto(&builder)
if err != nil {
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
if c.isEOFPacket(data) {
builder.DiscardLastField()
} else if isErrorPacket(data) {
return nil, false, 0, ParseErrorPacket(data)
} else {
return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
}
}

builder.Packet(0)
var rowcount int

// read each row until EOF or OK packet.
for {
data, err = c.readPacketAsProto(&builder)
if err != nil {
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}

if c.isEOFPacket(data) {
result := &sqltypes.Result{}

// The deprecated EOF packets change means that this is either an
// EOF packet or an OK packet with the EOF type code.
if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
var statusFlags uint16
warnings, statusFlags, err = parseEOFPacket(data)
if err != nil {
return nil, false, 0, err
}
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags

builder.DiscardLastField()
} else {
var packetEof PacketOK
if err := c.parseOKPacket(&packetEof, data); err != nil {
return nil, false, 0, err
}
warnings = packetEof.warnings
more = (packetEof.statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = packetEof.statusFlags

builder.DiscardLastField()
builder.SessionStateChanges(packetEof.sessionStateData)
builder.Info(packetEof.info)
}

result.CachedProto = builder.Finish()
return result, more, warnings, nil

} else if isErrorPacket(data) {
// Error packet.
return nil, false, 0, ParseErrorPacket(data)
}

if maxrows == FETCH_NO_ROWS {
continue
}

// Check we're not over the limit before we add more.
if rowcount == maxrows {
if err := c.drainResults(); err != nil {
return nil, false, 0, err
}
return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
}

rowcount++
}
}

// ReadQueryResult gets the result from the last written query.
func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) {
var packetOk PacketOK
Expand Down Expand Up @@ -512,7 +651,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
}

// Regular row.
row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil)
row, err := parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil)
if err != nil {
c.recycleReadPacket()
return nil, false, 0, err
Expand Down
Loading

0 comments on commit 42d98a9

Please sign in to comment.