diff --git a/go/vt/mysqlctl/backup.go b/go/vt/mysqlctl/backup.go index fb401966c50..0da2d18e06d 100644 --- a/go/vt/mysqlctl/backup.go +++ b/go/vt/mysqlctl/backup.go @@ -318,6 +318,10 @@ func ShouldRestore(ctx context.Context, params RestoreParams) (bool, error) { if err := params.Mysqld.Wait(ctx, params.Cnf); err != nil { return false, err } + if err := params.Mysqld.WaitForDBAGrants(ctx, DbaGrantWaitTime); err != nil { + params.Logger.Errorf("error waiting for the grants: %v", err) + return false, err + } return checkNoDB(ctx, params.Mysqld, params.DbName) } @@ -403,6 +407,10 @@ func Restore(ctx context.Context, params RestoreParams) (*BackupManifest, error) params.Logger.Errorf("mysqld is not running: %v", err) return nil, err } + if err = params.Mysqld.WaitForDBAGrants(ctx, DbaGrantWaitTime); err != nil { + params.Logger.Errorf("error waiting for the grants: %v", err) + return nil, err + } // Since this is an empty database make sure we start replication at the beginning if err := params.Mysqld.ResetReplication(ctx); err != nil { params.Logger.Errorf("error resetting replication: %v. Continuing", err) diff --git a/go/vt/mysqlctl/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon.go index 33a553a25e9..94c1f7f52c1 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -268,6 +268,10 @@ func (fmd *FakeMysqlDaemon) Wait(ctx context.Context, cnf *Mycnf) error { return nil } +func (fmd *FakeMysqlDaemon) WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) { + return nil +} + // GetMysqlPort is part of the MysqlDaemon interface. func (fmd *FakeMysqlDaemon) GetMysqlPort() (int32, error) { if fmd.MysqlPort.Load() == -1 { diff --git a/go/vt/mysqlctl/mysql_daemon.go b/go/vt/mysqlctl/mysql_daemon.go index 9e8baebefd6..4829af3d4f7 100644 --- a/go/vt/mysqlctl/mysql_daemon.go +++ b/go/vt/mysqlctl/mysql_daemon.go @@ -40,6 +40,7 @@ type MysqlDaemon interface { ReadBinlogFilesTimestamps(ctx context.Context, req *mysqlctlpb.ReadBinlogFilesTimestampsRequest) (*mysqlctlpb.ReadBinlogFilesTimestampsResponse, error) ReinitConfig(ctx context.Context, cnf *Mycnf) error Wait(ctx context.Context, cnf *Mycnf) error + WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) // GetMysqlPort returns the current port mysql is listening on. GetMysqlPort() (int32, error) diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index d866aa70f65..a1f6e257887 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -54,11 +54,10 @@ import ( "vitess.io/vitess/go/vt/hook" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl/mysqlctlclient" - "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/vterrors" - mysqlctlpb "vitess.io/vitess/go/vt/proto/mysqlctl" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/servenv" + "vitess.io/vitess/go/vt/vterrors" ) // The string we expect before the MySQL version number @@ -68,6 +67,9 @@ const versionStringPrefix = "Ver " // How many bytes from MySQL error log to sample for error messages const maxLogFileSampleSize = 4096 +// DbaGrantWaitTime is the amount of time to wait for the grants to have applied +const DbaGrantWaitTime = 10 * time.Second + var ( // DisableActiveReparents is a flag to disable active // reparents for safety reasons. It is used in three places: @@ -514,6 +516,40 @@ func (mysqld *Mysqld) Wait(ctx context.Context, cnf *Mycnf) error { return mysqld.wait(ctx, cnf, params) } +// WaitForDBAGrants waits for the grants to have applied for all the users. +func (mysqld *Mysqld) WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) { + if waitTime == 0 { + return nil + } + timer := time.NewTimer(waitTime) + ctx, cancel := context.WithTimeout(ctx, waitTime) + defer cancel() + for { + conn, connErr := dbconnpool.NewDBConnection(ctx, mysqld.dbcfgs.DbaConnector()) + if connErr == nil { + res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false) + conn.Close() + if fetchErr != nil { + log.Errorf("Error running SHOW GRANTS - %v", fetchErr) + } + if fetchErr == nil && res != nil && len(res.Rows) > 0 && len(res.Rows[0]) > 0 { + privileges := res.Rows[0][0].ToString() + // In MySQL 8.0, all the privileges are listed out explicitly, so we can search for SUPER in the output. + // In MySQL 5.7, all the privileges are not listed explicitly, instead ALL PRIVILEGES is written, so we search for that too. + if strings.Contains(privileges, "SUPER") || strings.Contains(privileges, "ALL PRIVILEGES") { + return nil + } + } + } + select { + case <-timer.C: + return fmt.Errorf("timed out after %v waiting for the dba user to have the required permissions", waitTime) + default: + time.Sleep(100 * time.Millisecond) + } + } +} + // wait is the internal version of Wait, that takes credentials. func (mysqld *Mysqld) wait(ctx context.Context, cnf *Mycnf, params *mysql.ConnParams) error { log.Infof("Waiting for mysqld socket file (%v) to be ready...", cnf.SocketFile) diff --git a/go/vt/vttablet/tabletmanager/tm_init.go b/go/vt/vttablet/tabletmanager/tm_init.go index f8931691ed5..6a1a4f1b730 100644 --- a/go/vt/vttablet/tabletmanager/tm_init.go +++ b/go/vt/vttablet/tabletmanager/tm_init.go @@ -79,7 +79,6 @@ import ( const ( // Query rules from denylist denyListQueryList string = "DenyListQueryRules" - dbaGrantWaitTime = 10 * time.Second ) var ( @@ -424,7 +423,7 @@ func (tm *TabletManager) Start(tablet *topodatapb.Tablet, config *tabletenv.Tabl } // Make sure we have the correct privileges for the DBA user before we start the state manager. - err = tm.waitForDBAGrants(config, dbaGrantWaitTime) + err = tm.waitForDBAGrants(config, mysqlctl.DbaGrantWaitTime) if err != nil { return err } @@ -822,7 +821,7 @@ func (tm *TabletManager) handleRestore(ctx context.Context, config *tabletenv.Ta } // Make sure we have the correct privileges for the DBA user before we start the state manager. - err := tm.waitForDBAGrants(config, dbaGrantWaitTime) + err := tm.waitForDBAGrants(config, mysqlctl.DbaGrantWaitTime) if err != nil { log.Exitf("Failed waiting for DBA grants: %v", err) } @@ -849,33 +848,7 @@ func (tm *TabletManager) waitForDBAGrants(config *tabletenv.TabletConfig, waitTi if config == nil || config.DB.HasGlobalSettings() || waitTime == 0 { return nil } - timer := time.NewTimer(waitTime) - ctx, cancel := context.WithTimeout(context.Background(), waitTime) - defer cancel() - for { - conn, connErr := dbconnpool.NewDBConnection(ctx, config.DB.DbaConnector()) - if connErr == nil { - res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false) - conn.Close() - if fetchErr != nil { - log.Errorf("Error running SHOW GRANTS - %v", fetchErr) - } - if fetchErr == nil && res != nil && len(res.Rows) > 0 && len(res.Rows[0]) > 0 { - privileges := res.Rows[0][0].ToString() - // In MySQL 8.0, all the privileges are listed out explicitly, so we can search for SUPER in the output. - // In MySQL 5.7, all the privileges are not listed explicitly, instead ALL PRIVILEGES is written, so we search for that too. - if strings.Contains(privileges, "SUPER") || strings.Contains(privileges, "ALL PRIVILEGES") { - return nil - } - } - } - select { - case <-timer.C: - return fmt.Errorf("timed out after %v waiting for the dba user to have the required permissions", waitTime) - default: - time.Sleep(100 * time.Millisecond) - } - } + return tm.MysqlDaemon.WaitForDBAGrants(context.Background(), waitTime) } func (tm *TabletManager) exportStats() { diff --git a/go/vt/vttablet/tabletmanager/tm_init_test.go b/go/vt/vttablet/tabletmanager/tm_init_test.go index c44bb846eb3..0e3256a1aac 100644 --- a/go/vt/vttablet/tabletmanager/tm_init_test.go +++ b/go/vt/vttablet/tabletmanager/tm_init_test.go @@ -858,8 +858,13 @@ func TestWaitForDBAGrants(t *testing.T) { t.Run(tt.name, func(t *testing.T) { config, cleanup := tt.setupFunc(t) defer cleanup() + var dm mysqlctl.MysqlDaemon + if config != nil { + dm = mysqlctl.NewMysqld(config.DB) + } tm := TabletManager{ _waitForGrantsComplete: make(chan struct{}), + MysqlDaemon: dm, } err := tm.waitForDBAGrants(config, tt.waitTime) if tt.errWanted == "" {