diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 4dcf87c4867..5d609b5fddf 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -1525,15 +1525,13 @@ type PacketOK struct { sessionStateData string } -func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) { +func (c *Conn) parseOKPacket(packetOK *PacketOK, in []byte) error { data := &coder{ data: in, pos: 1, // We already read the type. } - packetOK := &PacketOK{} - - fail := func(format string, args ...any) (*PacketOK, error) { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...) + fail := func(format string, args ...any) error { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...) } // Affected rows. @@ -1578,7 +1576,7 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) { if !ok || length == 0 { // In case we have no more data or a zero length string, there's no additional information so // we can return the packet. - return packetOK, nil + return nil } // Alright, now we need to read each sub packet from the session state change. @@ -1615,7 +1613,7 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) { } } - return packetOK, nil + return nil } // isErrorPacket determines whether or not the packet is an error packet. Mostly here for diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index da82a577753..5b218f38e13 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -246,7 +246,8 @@ func TestBasicPackets(t *testing.T) { require.NotEmpty(data) assert.EqualValues(data[0], OKPacket, "OKPacket") - packetOk, err := cConn.parseOKPacket(data) + var packetOk PacketOK + err = cConn.parseOKPacket(&packetOk, data) require.NoError(err) assert.EqualValues(12, packetOk.affectedRows) assert.EqualValues(34, packetOk.lastInsertID) @@ -272,7 +273,7 @@ func TestBasicPackets(t *testing.T) { require.NotEmpty(data) assert.EqualValues(data[0], OKPacket, "OKPacket") - packetOk, err = cConn.parseOKPacket(data) + err = cConn.parseOKPacket(&packetOk, data) require.NoError(err) assert.EqualValues(23, packetOk.affectedRows) assert.EqualValues(45, packetOk.lastInsertID) @@ -295,7 +296,7 @@ func TestBasicPackets(t *testing.T) { require.NotEmpty(data) assert.True(cConn.isEOFPacket(data), "expected EOF") - packetOk, err = cConn.parseOKPacket(data) + err = cConn.parseOKPacket(&packetOk, data) require.NoError(err) assert.EqualValues(12, packetOk.affectedRows) assert.EqualValues(34, packetOk.lastInsertID) @@ -690,7 +691,8 @@ func TestOkPackets(t *testing.T) { cConn.Capabilities = testCase.cc sConn.Capabilities = testCase.cc // parse the packet - packetOk, err := cConn.parseOKPacket(data) + var packetOk PacketOK + err := cConn.parseOKPacket(&packetOk, data) if testCase.expectedErr != "" { require.Error(t, err) require.Equal(t, testCase.expectedErr, err.Error()) @@ -699,7 +701,7 @@ func TestOkPackets(t *testing.T) { require.NoError(t, err, "failed to parse OK packet") // write the ok packet from server - err = sConn.writeOKPacket(packetOk) + err = sConn.writeOKPacket(&packetOk) require.NoError(t, err, "failed to write OK packet") // receive the ok packet on client diff --git a/go/mysql/query.go b/go/mysql/query.go index 758fa7cfe52..e1eaa7f9ea5 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -354,8 +354,9 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield // ReadQueryResult gets the result from the last written query. func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) { + var packetOk PacketOK // Get the result. - colNumber, packetOk, err := c.readComQueryResponse() + colNumber, err := c.readComQueryResponse(&packetOk) if err != nil { return nil, false, 0, err } @@ -441,8 +442,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, more = (statusFlags & ServerMoreResultsExists) != 0 result.StatusFlags = statusFlags } else { - packetOk, err := c.parseOKPacket(data) - if err != nil { + if err := c.parseOKPacket(&packetOk, data); err != nil { return nil, false, 0, err } warnings = packetOk.warnings @@ -497,35 +497,34 @@ func (c *Conn) drainResults() error { } } -func (c *Conn) readComQueryResponse() (int, *PacketOK, error) { +func (c *Conn) readComQueryResponse(packetOk *PacketOK) (int, error) { data, err := c.readEphemeralPacket() if err != nil { - return 0, nil, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + return 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) } defer c.recycleReadPacket() if len(data) == 0 { - return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet") + return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet") } switch data[0] { case OKPacket: - packetOk, err := c.parseOKPacket(data) - return 0, packetOk, err + return 0, c.parseOKPacket(packetOk, data) case ErrPacket: // Error - return 0, nil, ParseErrorPacket(data) + return 0, ParseErrorPacket(data) case 0xfb: // Local infile - return 0, nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") + return 0, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") } n, pos, ok := readLenEncInt(data, 0) if !ok { - return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number") + return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number") } if pos != len(data) { - return 0, nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response") + return 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response") } - return int(n), &PacketOK{}, nil + return int(n), nil } // diff --git a/go/mysql/streaming_query.go b/go/mysql/streaming_query.go index 452f1af3206..3d0d9ef49e8 100644 --- a/go/mysql/streaming_query.go +++ b/go/mysql/streaming_query.go @@ -49,7 +49,8 @@ func (c *Conn) ExecuteStreamFetch(query string) (err error) { } // Get the result. - colNumber, _, err := c.readComQueryResponse() + var packetOk PacketOK + colNumber, err := c.readComQueryResponse(&packetOk) if err != nil { return err }