diff --git a/go/cmd/mysqlctl/command/init.go b/go/cmd/mysqlctl/command/init.go index 14d8e5f6d29..afaf1c566df 100644 --- a/go/cmd/mysqlctl/command/init.go +++ b/go/cmd/mysqlctl/command/init.go @@ -55,7 +55,7 @@ func commandInit(cmd *cobra.Command, args []string) error { } defer mysqld.Close() - ctx, cancel := context.WithTimeout(context.Background(), initArgs.WaitTime) + ctx, cancel := context.WithTimeout(cmd.Context(), initArgs.WaitTime) defer cancel() if err := mysqld.Init(ctx, cnf, initArgs.InitDbSQLFile); err != nil { return fmt.Errorf("failed init mysql: %v", err) diff --git a/go/cmd/mysqlctl/command/shutdown.go b/go/cmd/mysqlctl/command/shutdown.go index 321d4a9b35f..30e0c8c0f8e 100644 --- a/go/cmd/mysqlctl/command/shutdown.go +++ b/go/cmd/mysqlctl/command/shutdown.go @@ -50,7 +50,7 @@ func commandShutdown(cmd *cobra.Command, args []string) error { } defer mysqld.Close() - ctx, cancel := context.WithTimeout(context.Background(), shutdownArgs.WaitTime+10*time.Second) + ctx, cancel := context.WithTimeout(cmd.Context(), shutdownArgs.WaitTime+10*time.Second) defer cancel() if err := mysqld.Shutdown(ctx, cnf, true, shutdownArgs.WaitTime); err != nil { return fmt.Errorf("failed shutdown mysql: %v", err) diff --git a/go/cmd/mysqlctl/command/start.go b/go/cmd/mysqlctl/command/start.go index aef404d0a8e..aff8d723a8b 100644 --- a/go/cmd/mysqlctl/command/start.go +++ b/go/cmd/mysqlctl/command/start.go @@ -51,7 +51,7 @@ func commandStart(cmd *cobra.Command, args []string) error { } defer mysqld.Close() - ctx, cancel := context.WithTimeout(context.Background(), startArgs.WaitTime) + ctx, cancel := context.WithTimeout(cmd.Context(), startArgs.WaitTime) defer cancel() if err := mysqld.Start(ctx, cnf, startArgs.MySQLdArgs...); err != nil { return fmt.Errorf("failed start mysql: %v", err) diff --git a/go/cmd/mysqlctl/command/teardown.go b/go/cmd/mysqlctl/command/teardown.go index 89d7b3b5f6d..3e7e7bfd0ef 100644 --- a/go/cmd/mysqlctl/command/teardown.go +++ b/go/cmd/mysqlctl/command/teardown.go @@ -53,7 +53,7 @@ func commandTeardown(cmd *cobra.Command, args []string) error { } defer mysqld.Close() - ctx, cancel := context.WithTimeout(context.Background(), teardownArgs.WaitTime+10*time.Second) + ctx, cancel := context.WithTimeout(cmd.Context(), teardownArgs.WaitTime+10*time.Second) defer cancel() if err := mysqld.Teardown(ctx, cnf, teardownArgs.Force, teardownArgs.WaitTime); err != nil { return fmt.Errorf("failed teardown mysql (forced? %v): %v", teardownArgs.Force, err) diff --git a/go/cmd/mysqlctld/cli/mysqlctld.go b/go/cmd/mysqlctld/cli/mysqlctld.go index 8dacf8d9d9f..ee3fe241440 100644 --- a/go/cmd/mysqlctld/cli/mysqlctld.go +++ b/go/cmd/mysqlctld/cli/mysqlctld.go @@ -114,7 +114,7 @@ func run(cmd *cobra.Command, args []string) error { } // Start or Init mysqld as needed. - ctx, cancel := context.WithTimeout(context.Background(), waitTime) + ctx, cancel := context.WithTimeout(cmd.Context(), waitTime) mycnfFile := mysqlctl.MycnfFile(tabletUID) if _, statErr := os.Stat(mycnfFile); os.IsNotExist(statErr) { // Generate my.cnf from scratch and use it to find mysqld. @@ -167,7 +167,7 @@ func run(cmd *cobra.Command, args []string) error { // Take mysqld down with us on SIGTERM before entering lame duck. servenv.OnTermSync(func() { log.Infof("mysqlctl received SIGTERM, shutting down mysqld first") - ctx, cancel := context.WithTimeout(context.Background(), shutdownWaitTime+10*time.Second) + ctx, cancel := context.WithTimeout(cmd.Context(), shutdownWaitTime+10*time.Second) defer cancel() if err := mysqld.Shutdown(ctx, cnf, true, shutdownWaitTime); err != nil { log.Errorf("failed to shutdown mysqld: %v", err) diff --git a/go/cmd/topo2topo/cli/topo2topo.go b/go/cmd/topo2topo/cli/topo2topo.go index f6f69e08eda..13539d97629 100644 --- a/go/cmd/topo2topo/cli/topo2topo.go +++ b/go/cmd/topo2topo/cli/topo2topo.go @@ -90,7 +90,7 @@ func run(cmd *cobra.Command, args []string) error { return fmt.Errorf("Cannot open 'to' topo %v: %w", toImplementation, err) } - ctx := context.Background() + ctx := cmd.Context() if compare { return compareTopos(ctx, fromTS, toTS) diff --git a/go/cmd/vtadmin/main.go b/go/cmd/vtadmin/main.go index 6cc3b9065b5..ad93d058c00 100644 --- a/go/cmd/vtadmin/main.go +++ b/go/cmd/vtadmin/main.go @@ -17,7 +17,6 @@ limitations under the License. package main import ( - "context" "flag" "io" "time" @@ -97,7 +96,7 @@ func startTracing(cmd *cobra.Command) { } func run(cmd *cobra.Command, args []string) { - bootSpan, ctx := trace.NewSpan(context.Background(), "vtadmin.boot") + bootSpan, ctx := trace.NewSpan(cmd.Context(), "vtadmin.boot") defer bootSpan.Finish() configs := clusterFileConfig.Combine(defaultClusterConfig, clusterConfigs) diff --git a/go/cmd/vtbackup/cli/vtbackup.go b/go/cmd/vtbackup/cli/vtbackup.go index c9f55333614..1b61c886ae7 100644 --- a/go/cmd/vtbackup/cli/vtbackup.go +++ b/go/cmd/vtbackup/cli/vtbackup.go @@ -221,10 +221,10 @@ func init() { collationEnv = collations.NewEnvironment(servenv.MySQLServerVersion()) } -func run(_ *cobra.Command, args []string) error { +func run(cc *cobra.Command, args []string) error { servenv.Init() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(cc.Context()) servenv.OnClose(func() { cancel() }) @@ -282,7 +282,7 @@ func run(_ *cobra.Command, args []string) error { return fmt.Errorf("Can't take backup: %w", err) } if doBackup { - if err := takeBackup(ctx, topoServer, backupStorage); err != nil { + if err := takeBackup(ctx, cc.Context(), topoServer, backupStorage); err != nil { return fmt.Errorf("Failed to take backup: %w", err) } } @@ -304,7 +304,7 @@ func run(_ *cobra.Command, args []string) error { return nil } -func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error { +func takeBackup(ctx, backgroundCtx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error { // This is an imaginary tablet alias. The value doesn't matter for anything, // except that we generate a random UID to ensure the target backup // directory is unique if multiple vtbackup instances are launched for the @@ -344,9 +344,9 @@ func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage back deprecatedDurationByPhase.Set("InitMySQLd", int64(time.Since(initMysqldAt).Seconds())) // Shut down mysqld when we're done. defer func() { - // Be careful not to use the original context, because we don't want to - // skip shutdown just because we timed out waiting for other things. - mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(context.Background(), mysqlShutdownTimeout+10*time.Second) + // Be careful use the background context, not the init one, because we don't want to + // skip shutdown just because we timed out waiting for init. + mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(backgroundCtx, mysqlShutdownTimeout+10*time.Second) defer mysqlShutdownCancel() if err := mysqld.Shutdown(mysqlShutdownCtx, mycnf, false, mysqlShutdownTimeout); err != nil { log.Errorf("failed to shutdown mysqld: %v", err) diff --git a/go/cmd/vtbench/cli/vtbench.go b/go/cmd/vtbench/cli/vtbench.go index 69b866bb60d..e36f06cb69e 100644 --- a/go/cmd/vtbench/cli/vtbench.go +++ b/go/cmd/vtbench/cli/vtbench.go @@ -212,7 +212,7 @@ func run(cmd *cobra.Command, args []string) error { b := vtbench.NewBench(threads, count, connParams, sql) - ctx, cancel := context.WithTimeout(context.Background(), deadline) + ctx, cancel := context.WithTimeout(cmd.Context(), deadline) defer cancel() fmt.Printf("Initializing test with %s protocol / %d threads / %d iterations\n", diff --git a/go/cmd/vtclient/cli/vtclient.go b/go/cmd/vtclient/cli/vtclient.go index 44b47c38dc7..e8bcd9b7ff2 100644 --- a/go/cmd/vtclient/cli/vtclient.go +++ b/go/cmd/vtclient/cli/vtclient.go @@ -197,7 +197,7 @@ func _run(cmd *cobra.Command, args []string) (*results, error) { log.Infof("Sending the query...") - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(cmd.Context(), timeout) defer cancel() return execMulti(ctx, db, cmd.Flags().Arg(0)) } diff --git a/go/cmd/vtclient/cli/vtclient_test.go b/go/cmd/vtclient/cli/vtclient_test.go index a323ab5c63a..bf0c1206167 100644 --- a/go/cmd/vtclient/cli/vtclient_test.go +++ b/go/cmd/vtclient/cli/vtclient_test.go @@ -17,6 +17,7 @@ limitations under the License. package cli import ( + "context" "fmt" "os" "strings" @@ -129,6 +130,7 @@ func TestVtclient(t *testing.T) { err := Main.ParseFlags(args) require.NoError(t, err) + Main.SetContext(context.Background()) results, err := _run(Main, args) if q.errMsg != "" { if got, want := err.Error(), q.errMsg; !strings.Contains(got, want) { diff --git a/go/cmd/vtcombo/cli/main.go b/go/cmd/vtcombo/cli/main.go index 88b0bd933a3..189441594bb 100644 --- a/go/cmd/vtcombo/cli/main.go +++ b/go/cmd/vtcombo/cli/main.go @@ -138,8 +138,8 @@ func init() { srvTopoCounts = stats.NewCountersWithSingleLabel("ResilientSrvTopoServer", "Resilient srvtopo server operations", "type") } -func startMysqld(uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +func startMysqld(ctx context.Context, uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() mycnfFile := mysqlctl.MycnfFile(uid) @@ -189,17 +189,20 @@ func run(cmd *cobra.Command, args []string) (err error) { cmd.Flags().Set("log_dir", "$VTDATAROOT/tmp") } + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() if externalTopoServer { // Open topo server based on the command line flags defined at topo/server.go // do not create cell info as it should be done by whoever sets up the external topo server ts = topo.Open() } else { // Create topo server. We use a 'memorytopo' implementation. - ts = memorytopo.NewServer(context.Background(), tpb.Cells...) + ts = memorytopo.NewServer(ctx, tpb.Cells...) } + defer ts.Close() // attempt to load any routing rules specified by tpb - if err := vtcombo.InitRoutingRules(context.Background(), ts, tpb.GetRoutingRules()); err != nil { + if err := vtcombo.InitRoutingRules(ctx, ts, tpb.GetRoutingRules()); err != nil { return fmt.Errorf("Failed to load routing rules: %w", err) } @@ -212,17 +215,17 @@ func run(cmd *cobra.Command, args []string) (err error) { ) if startMysql { - mysqld.Mysqld, cnf, err = startMysqld(1) + mysqld.Mysqld, cnf, err = startMysqld(ctx, 1) if err != nil { return err } servenv.OnClose(func() { - ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second) - defer cancel() - mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout) + shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second) + defer shutdownCancel() + mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout) }) // We want to ensure we can write to this database - mysqld.SetReadOnly(cmd.Context(), false) + mysqld.SetReadOnly(ctx, false) } else { dbconfigs.GlobalDBConfigs.InitWithSocket("", env.CollationEnv()) @@ -241,9 +244,9 @@ func run(cmd *cobra.Command, args []string) (err error) { if err != nil { // ensure we start mysql in the event we fail here if startMysql { - ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second) - defer cancel() - mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout) + startCtx, startCancel := context.WithTimeout(ctx, mysqlctl.DefaultShutdownTimeout+10*time.Second) + defer startCancel() + mysqld.Shutdown(startCtx, cnf, true, mysqlctl.DefaultShutdownTimeout) } return fmt.Errorf("initTabletMapProto failed: %w", err) @@ -287,12 +290,12 @@ func run(cmd *cobra.Command, args []string) (err error) { // Now that we have fully initialized the tablets, rebuild the keyspace graph. for _, ks := range tpb.Keyspaces { - err := topotools.RebuildKeyspace(context.Background(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false) + err := topotools.RebuildKeyspace(cmd.Context(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false) if err != nil { if startMysql { - ctx, cancel := context.WithTimeout(context.Background(), mysqlctl.DefaultShutdownTimeout+10*time.Second) - defer cancel() - mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout) + shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second) + defer shutdownCancel() + mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout) } return fmt.Errorf("Couldn't build srv keyspace for (%v: %v). Got error: %w", ks, tpb.Cells, err) @@ -300,7 +303,8 @@ func run(cmd *cobra.Command, args []string) (err error) { } // vtgate configuration and init - resilientServer = srvtopo.NewResilientServer(context.Background(), ts, srvTopoCounts) + + resilientServer = srvtopo.NewResilientServer(ctx, ts, srvTopoCounts) tabletTypes := make([]topodatapb.TabletType, 0, 1) if len(tabletTypesToWait) != 0 { @@ -324,7 +328,7 @@ func run(cmd *cobra.Command, args []string) (err error) { vtgate.QueryzHandler = "/debug/vtgate/queryz" // pass nil for healthcheck, it will get created - vtg := vtgate.Init(context.Background(), env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion) + vtg := vtgate.Init(ctx, env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion) // vtctld configuration and init err = vtctld.InitVtctld(env, ts) @@ -333,22 +337,13 @@ func run(cmd *cobra.Command, args []string) (err error) { } if vschemaPersistenceDir != "" && !externalTopoServer { - startVschemaWatcher(vschemaPersistenceDir, tpb.Keyspaces, ts) + startVschemaWatcher(ctx, vschemaPersistenceDir, ts) } servenv.OnRun(func() { addStatusParts(vtg) }) - servenv.OnTerm(func() { - log.Error("Terminating") - // FIXME(alainjobart): stop vtgate - }) - servenv.OnClose(func() { - // We will still use the topo server during lameduck period - // to update our state, so closing it in OnClose() - ts.Close() - }) servenv.RunDefault() return nil diff --git a/go/cmd/vtcombo/cli/vschema_watcher.go b/go/cmd/vtcombo/cli/vschema_watcher.go index 112a2a0730f..484c7736424 100644 --- a/go/cmd/vtcombo/cli/vschema_watcher.go +++ b/go/cmd/vtcombo/cli/vschema_watcher.go @@ -27,28 +27,27 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" vschemapb "vitess.io/vitess/go/vt/proto/vschema" - vttestpb "vitess.io/vitess/go/vt/proto/vttest" ) -func startVschemaWatcher(vschemaPersistenceDir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) { +func startVschemaWatcher(ctx context.Context, vschemaPersistenceDir string, ts *topo.Server) { // Create the directory if it doesn't exist. if err := createDirectoryIfNotExists(vschemaPersistenceDir); err != nil { log.Fatalf("Unable to create vschema persistence directory %v: %v", vschemaPersistenceDir, err) } // If there are keyspace files, load them. - loadKeyspacesFromDir(vschemaPersistenceDir, keyspaces, ts) + loadKeyspacesFromDir(ctx, vschemaPersistenceDir, ts) // Rebuild the SrvVSchema object in case we loaded vschema from file - if err := ts.RebuildSrvVSchema(context.Background(), tpb.Cells); err != nil { + if err := ts.RebuildSrvVSchema(ctx, tpb.Cells); err != nil { log.Fatalf("RebuildSrvVSchema failed: %v", err) } // Now watch for changes in the SrvVSchema object and persist them to disk. - go watchSrvVSchema(context.Background(), ts, tpb.Cells[0]) + go watchSrvVSchema(ctx, ts, tpb.Cells[0]) } -func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) { +func loadKeyspacesFromDir(ctx context.Context, dir string, ts *topo.Server) { for _, ks := range tpb.Keyspaces { ksFile := path.Join(dir, ks.Name+".json") if _, err := os.Stat(ksFile); err == nil { @@ -67,14 +66,14 @@ func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.S if err != nil { log.Fatalf("Invalid keyspace definition: %v", err) } - ts.SaveVSchema(context.Background(), ks.Name, keyspace) + ts.SaveVSchema(ctx, ks.Name, keyspace) log.Infof("Loaded keyspace %v from %v\n", ks.Name, ksFile) } } } func watchSrvVSchema(ctx context.Context, ts *topo.Server, cell string) { - data, ch, err := ts.WatchSrvVSchema(context.Background(), tpb.Cells[0]) + data, ch, err := ts.WatchSrvVSchema(ctx, tpb.Cells[0]) if err != nil { log.Fatalf("WatchSrvVSchema failed: %v", err) } diff --git a/go/cmd/vtctld/cli/cli.go b/go/cmd/vtctld/cli/cli.go index 4f8c57b6b2f..8cf208a66f0 100644 --- a/go/cmd/vtctld/cli/cli.go +++ b/go/cmd/vtctld/cli/cli.go @@ -79,7 +79,7 @@ func run(cmd *cobra.Command, args []string) error { vtctld.RegisterDebugHealthHandler(ts) // Start schema manager service. - initSchema() + initSchema(cmd.Context()) // And run the server. servenv.RunDefault() diff --git a/go/cmd/vtctld/cli/schema.go b/go/cmd/vtctld/cli/schema.go index 9f1f9d06072..a330a23abe2 100644 --- a/go/cmd/vtctld/cli/schema.go +++ b/go/cmd/vtctld/cli/schema.go @@ -47,7 +47,7 @@ func init() { Main.Flags().DurationVar(&schemaChangeReplicasTimeout, "schema_change_replicas_timeout", schemaChangeReplicasTimeout, "How long to wait for replicas to receive a schema change.") } -func initSchema() { +func initSchema(ctx context.Context) { // Start schema manager service if needed. if schemaChangeDir != "" { interval := schemaChangeCheckInterval @@ -70,7 +70,6 @@ func initSchema() { log.Errorf("failed to get controller, error: %v", err) return } - ctx := context.Background() wr := wrangler.New(env, logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient()) _, err = schemamanager.Run( ctx, diff --git a/go/cmd/vtctldclient/command/legacy_shim.go b/go/cmd/vtctldclient/command/legacy_shim.go index 95c3ea2d688..d7594e1fdff 100644 --- a/go/cmd/vtctldclient/command/legacy_shim.go +++ b/go/cmd/vtctldclient/command/legacy_shim.go @@ -43,7 +43,7 @@ var ( Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { cli.FinishedParsing(cmd) - return runLegacyCommand(args) + return runLegacyCommand(cmd.Context(), args) }, Long: strings.TrimSpace(` LegacyVtctlCommand uses the legacy vtctl grpc client to make an ExecuteVtctlCommand @@ -76,11 +76,11 @@ LegacyVtctlCommand -- AddCellInfo --server_address "localhost:5678" --root "/vit } ) -func runLegacyCommand(args []string) error { +func runLegacyCommand(ctx context.Context, args []string) error { // Duplicated (mostly) from go/cmd/vtctlclient/main.go. logger := logutil.NewConsoleLogger() - ctx, cancel := context.WithTimeout(context.Background(), actionTimeout) + ctx, cancel := context.WithTimeout(ctx, actionTimeout) defer cancel() err := vtctlclient.RunCommandAndWait(ctx, server, args, func(e *logutilpb.Event) { diff --git a/go/cmd/vtctldclient/command/root.go b/go/cmd/vtctldclient/command/root.go index 048b0786cb4..3ebe019f94d 100644 --- a/go/cmd/vtctldclient/command/root.go +++ b/go/cmd/vtctldclient/command/root.go @@ -111,7 +111,7 @@ connect directly to the topo server(s).`, useInternalVtctld), client, err = getClientForCommand(cmd) ctx := cmd.Context() if ctx == nil { - ctx = context.Background() + ctx = cmd.Context() } commandCtx, commandCancel = context.WithTimeout(ctx, actionTimeout) if compactOutput { diff --git a/go/cmd/vtexplain/cli/vtexplain.go b/go/cmd/vtexplain/cli/vtexplain.go index fe17d9af47c..824b0c31f84 100644 --- a/go/cmd/vtexplain/cli/vtexplain.go +++ b/go/cmd/vtexplain/cli/vtexplain.go @@ -139,10 +139,10 @@ func run(cmd *cobra.Command, args []string) error { defer logutil.Flush() servenv.Init() - return parseAndRun() + return parseAndRun(cmd.Context()) } -func parseAndRun() error { +func parseAndRun(ctx context.Context) error { plannerVersion, _ := plancontext.PlannerNameToVersion(plannerVersionStr) if plannerVersionStr != "" && plannerVersion != querypb.ExecuteOptions_Gen4 { return fmt.Errorf("invalid value specified for planner-version of '%s' -- valid value is Gen4 or an empty value to use the default planner", plannerVersionStr) @@ -185,7 +185,6 @@ func parseAndRun() error { if err != nil { return err } - ctx := context.Background() ts := memorytopo.NewServer(ctx, vtexplain.Cell) srvTopoCounts := stats.NewCountersWithSingleLabel("", "Resilient srvtopo server operations", "type") vte, err := vtexplain.Init(ctx, env, ts, vschema, schema, ksShardMap, opts, srvTopoCounts) diff --git a/go/cmd/vtgate/cli/cli.go b/go/cmd/vtgate/cli/cli.go index 312396a5a3c..e0040e2d880 100644 --- a/go/cmd/vtgate/cli/cli.go +++ b/go/cmd/vtgate/cli/cli.go @@ -143,10 +143,16 @@ func run(cmd *cobra.Command, args []string) error { servenv.Init() + // Ensure we open the topo before we start the context, so that the + // defer that closes the topo runs after cancelling the context. + // This ensures that we've properly closed things like the watchers + // at that point. ts := topo.Open() defer ts.Close() - resilientServer = srvtopo.NewResilientServer(context.Background(), ts, srvTopoCounts) + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + resilientServer = srvtopo.NewResilientServer(ctx, ts, srvTopoCounts) tabletTypes := make([]topodatapb.TabletType, 0, 1) for _, tt := range tabletTypesToWait { @@ -159,7 +165,7 @@ func run(cmd *cobra.Command, args []string) error { return fmt.Errorf("tablet_types_to_wait must contain at least one serving tablet type") } - err := CheckCellFlags(context.Background(), resilientServer, cell, vtgate.CellsToWatch) + err := CheckCellFlags(ctx, resilientServer, cell, vtgate.CellsToWatch) if err != nil { return fmt.Errorf("cells_to_watch validation failed: %v", err) } @@ -176,7 +182,7 @@ func run(cmd *cobra.Command, args []string) error { } // pass nil for HealthCheck and it will be created - vtg := vtgate.Init(context.Background(), env, nil, resilientServer, cell, tabletTypes, plannerVersion) + vtg := vtgate.Init(ctx, env, nil, resilientServer, cell, tabletTypes, plannerVersion) servenv.OnRun(func() { // Flags are parsed now. Parse the template using the actual flag value and overwrite the current template. @@ -184,7 +190,7 @@ func run(cmd *cobra.Command, args []string) error { addStatusParts(vtg) }) servenv.OnClose(func() { - _ = vtg.Gateway().Close(context.Background()) + _ = vtg.Gateway().Close(ctx) }) servenv.RunDefault() diff --git a/go/cmd/vttablet/cli/cli.go b/go/cmd/vttablet/cli/cli.go index 45f8ac091dc..3fb1e98877f 100644 --- a/go/cmd/vttablet/cli/cli.go +++ b/go/cmd/vttablet/cli/cli.go @@ -110,6 +110,16 @@ func init() { func run(cmd *cobra.Command, args []string) error { servenv.Init() + // Ensure we open the topo before we start the context, so that the + // defer that closes the topo runs after cancelling the context. + // This ensures that we've properly closed things like the watchers + // at that point. + ts := topo.Open() + defer ts.Close() + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + tabletAlias, err := topoproto.ParseTabletAlias(tabletPath) if err != nil { return fmt.Errorf("failed to parse --tablet-path: %w", err) @@ -131,8 +141,7 @@ func run(cmd *cobra.Command, args []string) error { return err } - ts := topo.Open() - qsc, err := createTabletServer(context.Background(), env, config, ts, tabletAlias, srvTopoCounts) + qsc, err := createTabletServer(ctx, env, config, ts, tabletAlias, srvTopoCounts) if err != nil { ts.Close() return err @@ -151,7 +160,7 @@ func run(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to parse --tablet-path: %w", err) } tm = &tabletmanager.TabletManager{ - BatchCtx: context.Background(), + BatchCtx: ctx, Env: env, TopoServer: ts, Cnf: mycnf, @@ -170,9 +179,6 @@ func run(cmd *cobra.Command, args []string) error { // Close the tm so that our topo entry gets pruned properly and any // background goroutines that use the topo connection are stopped. tm.Close() - - // tm uses ts. So, it should be closed after tm. - ts.Close() }) servenv.RunDefault()