Skip to content

Commit

Permalink
go/cmd: Audit and fix context.Background() usage (vitessio#15928)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored May 15, 2024
1 parent 8639de2 commit 61061cb
Show file tree
Hide file tree
Showing 20 changed files with 79 additions and 74 deletions.
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func commandInit(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), initArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), initArgs.WaitTime)
defer cancel()
if err := mysqld.Init(ctx, cnf, initArgs.InitDbSQLFile); err != nil {
return fmt.Errorf("failed init mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func commandShutdown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), shutdownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownArgs.WaitTime); err != nil {
return fmt.Errorf("failed shutdown mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func commandStart(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), startArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), startArgs.WaitTime)
defer cancel()
if err := mysqld.Start(ctx, cnf, startArgs.MySQLdArgs...); err != nil {
return fmt.Errorf("failed start mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/teardown.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func commandTeardown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), teardownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), teardownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Teardown(ctx, cnf, teardownArgs.Force, teardownArgs.WaitTime); err != nil {
return fmt.Errorf("failed teardown mysql (forced? %v): %v", teardownArgs.Force, err)
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/mysqlctld/cli/mysqlctld.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func run(cmd *cobra.Command, args []string) error {
}

// Start or Init mysqld as needed.
ctx, cancel := context.WithTimeout(context.Background(), waitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), waitTime)
mycnfFile := mysqlctl.MycnfFile(tabletUID)
if _, statErr := os.Stat(mycnfFile); os.IsNotExist(statErr) {
// Generate my.cnf from scratch and use it to find mysqld.
Expand Down Expand Up @@ -167,7 +167,7 @@ func run(cmd *cobra.Command, args []string) error {
// Take mysqld down with us on SIGTERM before entering lame duck.
servenv.OnTermSync(func() {
log.Infof("mysqlctl received SIGTERM, shutting down mysqld first")
ctx, cancel := context.WithTimeout(context.Background(), shutdownWaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownWaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownWaitTime); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/topo2topo/cli/topo2topo.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func run(cmd *cobra.Command, args []string) error {
return fmt.Errorf("Cannot open 'to' topo %v: %w", toImplementation, err)
}

ctx := context.Background()
ctx := cmd.Context()

if compare {
return compareTopos(ctx, fromTS, toTS)
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtadmin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package main

import (
"context"
"flag"
"io"
"time"
Expand Down Expand Up @@ -97,7 +96,7 @@ func startTracing(cmd *cobra.Command) {
}

func run(cmd *cobra.Command, args []string) {
bootSpan, ctx := trace.NewSpan(context.Background(), "vtadmin.boot")
bootSpan, ctx := trace.NewSpan(cmd.Context(), "vtadmin.boot")
defer bootSpan.Finish()

configs := clusterFileConfig.Combine(defaultClusterConfig, clusterConfigs)
Expand Down
14 changes: 7 additions & 7 deletions go/cmd/vtbackup/cli/vtbackup.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ func init() {
collationEnv = collations.NewEnvironment(servenv.MySQLServerVersion())
}

func run(_ *cobra.Command, args []string) error {
func run(cc *cobra.Command, args []string) error {
servenv.Init()

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cc.Context())
servenv.OnClose(func() {
cancel()
})
Expand Down Expand Up @@ -282,7 +282,7 @@ func run(_ *cobra.Command, args []string) error {
return fmt.Errorf("Can't take backup: %w", err)
}
if doBackup {
if err := takeBackup(ctx, topoServer, backupStorage); err != nil {
if err := takeBackup(ctx, cc.Context(), topoServer, backupStorage); err != nil {
return fmt.Errorf("Failed to take backup: %w", err)
}
}
Expand All @@ -304,7 +304,7 @@ func run(_ *cobra.Command, args []string) error {
return nil
}

func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
func takeBackup(ctx, backgroundCtx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
// This is an imaginary tablet alias. The value doesn't matter for anything,
// except that we generate a random UID to ensure the target backup
// directory is unique if multiple vtbackup instances are launched for the
Expand Down Expand Up @@ -344,9 +344,9 @@ func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage back
deprecatedDurationByPhase.Set("InitMySQLd", int64(time.Since(initMysqldAt).Seconds()))
// Shut down mysqld when we're done.
defer func() {
// Be careful not to use the original context, because we don't want to
// skip shutdown just because we timed out waiting for other things.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(context.Background(), mysqlShutdownTimeout+10*time.Second)
// Be careful use the background context, not the init one, because we don't want to
// skip shutdown just because we timed out waiting for init.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(backgroundCtx, mysqlShutdownTimeout+10*time.Second)
defer mysqlShutdownCancel()
if err := mysqld.Shutdown(mysqlShutdownCtx, mycnf, false, mysqlShutdownTimeout); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtbench/cli/vtbench.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func run(cmd *cobra.Command, args []string) error {

b := vtbench.NewBench(threads, count, connParams, sql)

ctx, cancel := context.WithTimeout(context.Background(), deadline)
ctx, cancel := context.WithTimeout(cmd.Context(), deadline)
defer cancel()

fmt.Printf("Initializing test with %s protocol / %d threads / %d iterations\n",
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtclient/cli/vtclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func _run(cmd *cobra.Command, args []string) (*results, error) {

log.Infof("Sending the query...")

ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(cmd.Context(), timeout)
defer cancel()
return execMulti(ctx, db, cmd.Flags().Arg(0))
}
Expand Down
2 changes: 2 additions & 0 deletions go/cmd/vtclient/cli/vtclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cli

import (
"context"
"fmt"
"os"
"strings"
Expand Down Expand Up @@ -129,6 +130,7 @@ func TestVtclient(t *testing.T) {
err := Main.ParseFlags(args)
require.NoError(t, err)

Main.SetContext(context.Background())
results, err := _run(Main, args)
if q.errMsg != "" {
if got, want := err.Error(), q.errMsg; !strings.Contains(got, want) {
Expand Down
51 changes: 23 additions & 28 deletions go/cmd/vtcombo/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func init() {
srvTopoCounts = stats.NewCountersWithSingleLabel("ResilientSrvTopoServer", "Resilient srvtopo server operations", "type")
}

func startMysqld(uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
func startMysqld(ctx context.Context, uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

mycnfFile := mysqlctl.MycnfFile(uid)
Expand Down Expand Up @@ -189,17 +189,20 @@ func run(cmd *cobra.Command, args []string) (err error) {
cmd.Flags().Set("log_dir", "$VTDATAROOT/tmp")
}

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
if externalTopoServer {
// Open topo server based on the command line flags defined at topo/server.go
// do not create cell info as it should be done by whoever sets up the external topo server
ts = topo.Open()
} else {
// Create topo server. We use a 'memorytopo' implementation.
ts = memorytopo.NewServer(context.Background(), tpb.Cells...)
ts = memorytopo.NewServer(ctx, tpb.Cells...)
}
defer ts.Close()

// attempt to load any routing rules specified by tpb
if err := vtcombo.InitRoutingRules(context.Background(), ts, tpb.GetRoutingRules()); err != nil {
if err := vtcombo.InitRoutingRules(ctx, ts, tpb.GetRoutingRules()); err != nil {
return fmt.Errorf("Failed to load routing rules: %w", err)
}

Expand All @@ -212,17 +215,17 @@ func run(cmd *cobra.Command, args []string) (err error) {
)

if startMysql {
mysqld.Mysqld, cnf, err = startMysqld(1)
mysqld.Mysqld, cnf, err = startMysqld(ctx, 1)
if err != nil {
return err
}
servenv.OnClose(func() {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
})
// We want to ensure we can write to this database
mysqld.SetReadOnly(cmd.Context(), false)
mysqld.SetReadOnly(ctx, false)

} else {
dbconfigs.GlobalDBConfigs.InitWithSocket("", env.CollationEnv())
Expand All @@ -241,9 +244,9 @@ func run(cmd *cobra.Command, args []string) (err error) {
if err != nil {
// ensure we start mysql in the event we fail here
if startMysql {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
startCtx, startCancel := context.WithTimeout(ctx, mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer startCancel()
mysqld.Shutdown(startCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("initTabletMapProto failed: %w", err)
Expand Down Expand Up @@ -287,20 +290,21 @@ func run(cmd *cobra.Command, args []string) (err error) {

// Now that we have fully initialized the tablets, rebuild the keyspace graph.
for _, ks := range tpb.Keyspaces {
err := topotools.RebuildKeyspace(context.Background(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
err := topotools.RebuildKeyspace(cmd.Context(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
if err != nil {
if startMysql {
ctx, cancel := context.WithTimeout(context.Background(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("Couldn't build srv keyspace for (%v: %v). Got error: %w", ks, tpb.Cells, err)
}
}

// vtgate configuration and init
resilientServer = srvtopo.NewResilientServer(context.Background(), ts, srvTopoCounts)

resilientServer = srvtopo.NewResilientServer(ctx, ts, srvTopoCounts)

tabletTypes := make([]topodatapb.TabletType, 0, 1)
if len(tabletTypesToWait) != 0 {
Expand All @@ -324,7 +328,7 @@ func run(cmd *cobra.Command, args []string) (err error) {
vtgate.QueryzHandler = "/debug/vtgate/queryz"

// pass nil for healthcheck, it will get created
vtg := vtgate.Init(context.Background(), env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)
vtg := vtgate.Init(ctx, env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)

// vtctld configuration and init
err = vtctld.InitVtctld(env, ts)
Expand All @@ -333,22 +337,13 @@ func run(cmd *cobra.Command, args []string) (err error) {
}

if vschemaPersistenceDir != "" && !externalTopoServer {
startVschemaWatcher(vschemaPersistenceDir, tpb.Keyspaces, ts)
startVschemaWatcher(ctx, vschemaPersistenceDir, ts)
}

servenv.OnRun(func() {
addStatusParts(vtg)
})

servenv.OnTerm(func() {
log.Error("Terminating")
// FIXME(alainjobart): stop vtgate
})
servenv.OnClose(func() {
// We will still use the topo server during lameduck period
// to update our state, so closing it in OnClose()
ts.Close()
})
servenv.RunDefault()

return nil
Expand Down
15 changes: 7 additions & 8 deletions go/cmd/vtcombo/cli/vschema_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,27 @@ import (
"vitess.io/vitess/go/vt/vtgate/vindexes"

vschemapb "vitess.io/vitess/go/vt/proto/vschema"
vttestpb "vitess.io/vitess/go/vt/proto/vttest"
)

func startVschemaWatcher(vschemaPersistenceDir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func startVschemaWatcher(ctx context.Context, vschemaPersistenceDir string, ts *topo.Server) {
// Create the directory if it doesn't exist.
if err := createDirectoryIfNotExists(vschemaPersistenceDir); err != nil {
log.Fatalf("Unable to create vschema persistence directory %v: %v", vschemaPersistenceDir, err)
}

// If there are keyspace files, load them.
loadKeyspacesFromDir(vschemaPersistenceDir, keyspaces, ts)
loadKeyspacesFromDir(ctx, vschemaPersistenceDir, ts)

// Rebuild the SrvVSchema object in case we loaded vschema from file
if err := ts.RebuildSrvVSchema(context.Background(), tpb.Cells); err != nil {
if err := ts.RebuildSrvVSchema(ctx, tpb.Cells); err != nil {
log.Fatalf("RebuildSrvVSchema failed: %v", err)
}

// Now watch for changes in the SrvVSchema object and persist them to disk.
go watchSrvVSchema(context.Background(), ts, tpb.Cells[0])
go watchSrvVSchema(ctx, ts, tpb.Cells[0])
}

func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func loadKeyspacesFromDir(ctx context.Context, dir string, ts *topo.Server) {
for _, ks := range tpb.Keyspaces {
ksFile := path.Join(dir, ks.Name+".json")
if _, err := os.Stat(ksFile); err == nil {
Expand All @@ -67,14 +66,14 @@ func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.S
if err != nil {
log.Fatalf("Invalid keyspace definition: %v", err)
}
ts.SaveVSchema(context.Background(), ks.Name, keyspace)
ts.SaveVSchema(ctx, ks.Name, keyspace)
log.Infof("Loaded keyspace %v from %v\n", ks.Name, ksFile)
}
}
}

func watchSrvVSchema(ctx context.Context, ts *topo.Server, cell string) {
data, ch, err := ts.WatchSrvVSchema(context.Background(), tpb.Cells[0])
data, ch, err := ts.WatchSrvVSchema(ctx, tpb.Cells[0])
if err != nil {
log.Fatalf("WatchSrvVSchema failed: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctld/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func run(cmd *cobra.Command, args []string) error {
vtctld.RegisterDebugHealthHandler(ts)

// Start schema manager service.
initSchema()
initSchema(cmd.Context())

// And run the server.
servenv.RunDefault()
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtctld/cli/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func init() {
Main.Flags().DurationVar(&schemaChangeReplicasTimeout, "schema_change_replicas_timeout", schemaChangeReplicasTimeout, "How long to wait for replicas to receive a schema change.")
}

func initSchema() {
func initSchema(ctx context.Context) {
// Start schema manager service if needed.
if schemaChangeDir != "" {
interval := schemaChangeCheckInterval
Expand All @@ -70,7 +70,6 @@ func initSchema() {
log.Errorf("failed to get controller, error: %v", err)
return
}
ctx := context.Background()
wr := wrangler.New(env, logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient())
_, err = schemamanager.Run(
ctx,
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/vtctldclient/command/legacy_shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var (
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
cli.FinishedParsing(cmd)
return runLegacyCommand(args)
return runLegacyCommand(cmd.Context(), args)
},
Long: strings.TrimSpace(`
LegacyVtctlCommand uses the legacy vtctl grpc client to make an ExecuteVtctlCommand
Expand Down Expand Up @@ -76,11 +76,11 @@ LegacyVtctlCommand -- AddCellInfo --server_address "localhost:5678" --root "/vit
}
)

func runLegacyCommand(args []string) error {
func runLegacyCommand(ctx context.Context, args []string) error {
// Duplicated (mostly) from go/cmd/vtctlclient/main.go.
logger := logutil.NewConsoleLogger()

ctx, cancel := context.WithTimeout(context.Background(), actionTimeout)
ctx, cancel := context.WithTimeout(ctx, actionTimeout)
defer cancel()

err := vtctlclient.RunCommandAndWait(ctx, server, args, func(e *logutilpb.Event) {
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctldclient/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ connect directly to the topo server(s).`, useInternalVtctld),
client, err = getClientForCommand(cmd)
ctx := cmd.Context()
if ctx == nil {
ctx = context.Background()
ctx = cmd.Context()
}
commandCtx, commandCancel = context.WithTimeout(ctx, actionTimeout)
if compactOutput {
Expand Down
Loading

0 comments on commit 61061cb

Please sign in to comment.