diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index 61816b16d08..4f3d5fe893d 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -166,7 +166,7 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi // Check if the context is already past its deadline before // trying to execute the query. if err := ctx.Err(); err != nil { - return nil, fmt.Errorf("%v before execution started", err) + return nil, vterrors.Errorf(vtrpcpb.Code_CANCELED, "%s before execution started", dbc.getErrorMessageFromContextError(ctx)) } now := time.Now() @@ -200,8 +200,8 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi } } -// terminate kills the query or connection based on the transaction status -func (dbc *Conn) terminate(ctx context.Context, insideTxn bool, now time.Time) { +// getErrorMessageFromContextError gets the error message from context error. +func (dbc *Conn) getErrorMessageFromContextError(ctx context.Context) string { var errMsg string switch { case errors.Is(ctx.Err(), context.DeadlineExceeded): @@ -211,6 +211,12 @@ func (dbc *Conn) terminate(ctx context.Context, insideTxn bool, now time.Time) { default: errMsg = ctx.Err().Error() } + return errMsg +} + +// terminate kills the query or connection based on the transaction status +func (dbc *Conn) terminate(ctx context.Context, insideTxn bool, now time.Time) { + errMsg := dbc.getErrorMessageFromContextError(ctx) if insideTxn { // we can't safely kill a query in a transaction, we need to kill the connection _ = dbc.Kill(errMsg, time.Since(now)) @@ -229,7 +235,7 @@ func (dbc *Conn) FetchNext(ctx context.Context, maxrows int, wantfields bool) (* // Check if the context is already past its deadline before // trying to fetch the next result. if err := ctx.Err(); err != nil { - return nil, fmt.Errorf("%v before reading next result set", err) + return nil, vterrors.Errorf(vtrpcpb.Code_CANCELED, "%s before reading next result set", dbc.getErrorMessageFromContextError(ctx)) } res, _, _, err := dbc.conn.ReadQueryResult(maxrows, wantfields) if err != nil { diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go index 6f3c77de528..1d9104c4354 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "strings" "sync/atomic" "testing" "time" @@ -76,18 +75,12 @@ func TestDBConnExec(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) // Exec succeed, not asking for fields. result, err := dbConn.Exec(ctx, sql, 1, false) - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) expectedResult.Fields = nil - if !expectedResult.Equal(result) { - t.Errorf("Exec: %v, want %v", expectedResult, result) - } + require.True(t, expectedResult.Equal(result)) compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts()) @@ -100,10 +93,8 @@ func TestDBConnExec(t *testing.T) { Query: "", }) _, err = dbConn.Exec(ctx, sql, 1, false) - want := "connection fail" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "connection fail") // The client side error triggers a retry in exec. compareTimingCounts(t, "PoolTest.Exec", 2, startCounts, mysqlTimings.Counts()) @@ -114,10 +105,8 @@ func TestDBConnExec(t *testing.T) { // This time the initial query fails as does the reconnect attempt. db.EnableConnFail() _, err = dbConn.Exec(ctx, sql, 1, false) - want = "packet read failed" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "packet read failed") db.DisableConnFail() compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts()) @@ -150,14 +139,10 @@ func TestDBConnExecLost(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) // Exec succeed, not asking for fields. result, err := dbConn.Exec(ctx, sql, 1, false) - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) expectedResult.Fields = nil if !expectedResult.Equal(result) { t.Errorf("Exec: %v, want %v", expectedResult, result) @@ -173,10 +158,8 @@ func TestDBConnExecLost(t *testing.T) { Query: "", }) _, err = dbConn.Exec(ctx, sql, 1, false) - want := "Lost connection to MySQL server during query" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "Lost connection to MySQL server during query") // Should *not* see a retry, so only increment by 1 compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts()) @@ -212,15 +195,11 @@ func TestDBConnDeadline(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) _, err = dbConn.Exec(ctx, sql, 1, false) - want := "context deadline exceeded before execution started" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "(errno 3024) (sqlstate HY000): Query execution was interrupted, maximum statement execution time exceeded before execution started") compareTimingCounts(t, "PoolTest.Exec", 0, startCounts, mysqlTimings.Counts()) @@ -230,9 +209,7 @@ func TestDBConnDeadline(t *testing.T) { defer cancel() result, err := dbConn.Exec(ctx, sql, 1, false) - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) expectedResult.Fields = nil if !expectedResult.Equal(result) { t.Errorf("Exec: %v, want %v", expectedResult, result) @@ -244,9 +221,7 @@ func TestDBConnDeadline(t *testing.T) { // Test with just the Background context (with no deadline) result, err = dbConn.Exec(context.Background(), sql, 1, false) - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) expectedResult.Fields = nil if !expectedResult.Equal(result) { t.Errorf("Exec: %v, want %v", expectedResult, result) @@ -266,18 +241,14 @@ func TestDBConnKill(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) query := fmt.Sprintf("kill %d", dbConn.ID()) db.AddQuery(query, &sqltypes.Result{}) // Kill failed because we are not able to connect to the database db.EnableConnFail() err = dbConn.Kill("test kill", 0) - want := "errno 2013" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "errno 2013") db.DisableConnFail() // Kill succeed @@ -294,10 +265,8 @@ func TestDBConnKill(t *testing.T) { // Kill failed because "kill query_id" failed db.AddRejectedQuery(newKillQuery, errors.New("rejected")) err = dbConn.Kill("test kill", 0) - want = "rejected" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "rejected") } func TestDBKillWithContext(t *testing.T) { @@ -479,18 +448,17 @@ func TestDBNoPoolConnKill(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) query := fmt.Sprintf("kill %d", dbConn.ID()) db.AddQuery(query, &sqltypes.Result{}) // Kill failed because we are not able to connect to the database db.EnableConnFail() err = dbConn.Kill("test kill", 0) - want := "errno 2013" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + var sqlErr *sqlerror.SQLError + isSqlErr := errors.As(sqlerror.NewSQLErrorFromError(err), &sqlErr) + require.True(t, isSqlErr) + require.EqualValues(t, sqlerror.CRServerLost, sqlErr.Number()) db.DisableConnFail() // Kill succeed @@ -507,10 +475,8 @@ func TestDBNoPoolConnKill(t *testing.T) { // Kill failed because "kill query_id" failed db.AddRejectedQuery(newKillQuery, errors.New("rejected")) err = dbConn.Kill("test kill", 0) - want = "rejected" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Exec: %v, want %s", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "rejected") } func TestDBConnStream(t *testing.T) { @@ -536,9 +502,7 @@ func TestDBConnStream(t *testing.T) { if dbConn != nil { defer dbConn.Close() } - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } + require.NoError(t, err) var result sqltypes.Result err = dbConn.Stream( ctx, sql, func(r *sqltypes.Result) error { @@ -552,12 +516,8 @@ func TestDBConnStream(t *testing.T) { return nil }, alloc, 10, querypb.ExecuteOptions_ALL) - if err != nil { - t.Fatalf("should not get an error, err: %v", err) - } - if !expectedResult.Equal(&result) { - t.Errorf("Exec: %v, want %v", expectedResult, &result) - } + require.NoError(t, err) + require.True(t, expectedResult.Equal(&result)) // Stream fail db.Close() dbConn.Close() @@ -569,10 +529,8 @@ func TestDBConnStream(t *testing.T) { }, 10, querypb.ExecuteOptions_ALL) db.DisableConnFail() - want := "no such file or directory (errno 2002)" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Error: '%v', must contain '%s'", err, want) - } + require.Error(t, err) + require.ErrorContains(t, err, "no such file or directory (errno 2002)") } // TestDBConnKillCall tests that direct Kill method calls work as expected.