From ba129c7d04e04859f8a5e859aa5f34c85dfa285f Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Tue, 22 Oct 2024 02:15:11 -0400 Subject: [PATCH] Flakes: Setup new fake server if it has gone away (#17023) Signed-off-by: Matt Lord --- go/mysql/auth_server_clientcert_test.go | 70 +++-- go/mysql/auth_server_static.go | 24 +- go/mysql/auth_server_static_test.go | 29 +- go/mysql/handshake_test.go | 29 +- go/mysql/replication_test.go | 5 + go/mysql/server_test.go | 369 +++++++++++------------- 6 files changed, 273 insertions(+), 253 deletions(-) diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index eff92053d94..72a1ecce87c 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -27,13 +27,24 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/tlstest" "vitess.io/vitess/go/vt/vttls" ) const clientCertUsername = "Client Cert" +// The listener's Accept() loop actually only ends on a connection +// error, which will occur when trying to connect after the listener +// has been closed. So this function closes the listener and then +// calls Connect to trigger the error which ends that work. +var cleanupListener = func(ctx context.Context, l *Listener, params *ConnParams) { + l.Close() + _, _ = Connect(ctx, params) +} + func TestValidCert(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerClientCert(string(MysqlClearPassword)) @@ -52,21 +63,6 @@ func TestValidCert(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername) tlstest.CreateCRL(root, tlstest.CA) - // Create the server with TLS config. - serverConfig, err := vttls.ServerConfig( - path.Join(root, "server-cert.pem"), - path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem"), - path.Join(root, "ca-crl.pem"), - "", - tls.VersionTLS12) - require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -81,7 +77,20 @@ func TestValidCert(t *testing.T) { ServerName: "server.example.com", } - ctx := context.Background() + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + path.Join(root, "ca-crl.pem"), + "", + tls.VersionTLS12) + require.NoError(t, err, "TLSServerConfig failed: %v", err) + + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer cleanupListener(ctx, l, params) + conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) @@ -103,6 +112,7 @@ func TestValidCert(t *testing.T) { } func TestNoCert(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerClientCert(string(MysqlClearPassword)) @@ -120,6 +130,17 @@ func TestNoCert(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateCRL(root, tlstest.CA) + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "", + SslMode: vttls.VerifyIdentity, + SslCa: path.Join(root, "ca-cert.pem"), + ServerName: "server.example.com", + } + // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -131,22 +152,9 @@ func TestNoCert(t *testing.T) { require.NoError(t, err, "TLSServerConfig failed: %v", err) l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - - // Setup the right parameters. - params := &ConnParams{ - Host: host, - Port: port, - Uname: "user1", - Pass: "", - SslMode: vttls.VerifyIdentity, - SslCa: path.Join(root, "ca-cert.pem"), - ServerName: "server.example.com", - } + go l.Accept() + defer cleanupListener(ctx, l, params) - ctx := context.Background() conn, err := Connect(ctx, params) assert.Error(t, err, "Connect() should have errored due to no client cert") diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index d9e6decf5e5..46302bcabe1 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -50,8 +50,10 @@ type AuthServerStatic struct { // entries contains the users, passwords and user data. entries map[string][]*AuthServerStaticEntry + // Signal handling related fields. sigChan chan os.Signal ticker *time.Ticker + done chan struct{} // Tell the signal related goroutines to stop } // AuthServerStaticEntry stores the values for a given user. @@ -267,11 +269,17 @@ func (a *AuthServerStatic) installSignalHandlers() { return } + a.done = make(chan struct{}) a.sigChan = make(chan os.Signal, 1) signal.Notify(a.sigChan, syscall.SIGHUP) go func() { - for range a.sigChan { - a.reload() + for { + select { + case <-a.done: + return + case <-a.sigChan: + a.reload() + } } }() @@ -279,14 +287,22 @@ func (a *AuthServerStatic) installSignalHandlers() { if a.reloadInterval > 0 { a.ticker = time.NewTicker(a.reloadInterval) go func() { - for range a.ticker.C { - a.sigChan <- syscall.SIGHUP + for { + select { + case <-a.done: + return + case <-a.ticker.C: + a.sigChan <- syscall.SIGHUP + } } }() } } func (a *AuthServerStatic) close() { + if a.done != nil { + close(a.done) + } if a.ticker != nil { a.ticker.Stop() } diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 12ae74e0d60..a808ce9b66b 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -25,6 +25,8 @@ import ( "time" "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/test/utils" ) // getEntries is a test-only method for AuthServerStatic. @@ -35,6 +37,7 @@ func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry { } func TestJsonConfigParser(t *testing.T) { + _ = utils.LeakCheckContext(t) // works with legacy format config := make(map[string][]*AuthServerStaticEntry) jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}" @@ -67,6 +70,7 @@ func TestJsonConfigParser(t *testing.T) { } func TestValidateHashGetter(t *testing.T) { + _ = utils.LeakCheckContext(t) jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}` auth := NewAuthServerStatic("", jsonConfig, 0) @@ -90,6 +94,7 @@ func TestValidateHashGetter(t *testing.T) { } func TestHostMatcher(t *testing.T) { + _ = utils.LeakCheckContext(t) ip := net.ParseIP("192.168.0.1") addr := &net.TCPAddr{IP: ip, Port: 9999} match := MatchSourceHost(net.Addr(addr), "") @@ -105,9 +110,9 @@ func TestHostMatcher(t *testing.T) { } func TestStaticConfigHUP(t *testing.T) { + _ = utils.LeakCheckContext(t) tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") require.NoError(t, err, "couldn't create temp file: %v", err) - defer os.Remove(tmpFile.Name()) oldStr := "str5" @@ -125,14 +130,19 @@ func TestStaticConfigHUP(t *testing.T) { mu.Lock() defer mu.Unlock() - // delete registered Auth server - clear(authServers) + // Delete registered Auth servers. + for k, v := range authServers { + if s, ok := v.(*AuthServerStatic); ok { + s.close() + } + delete(authServers, k) + } } func TestStaticConfigHUPWithRotation(t *testing.T) { + _ = utils.LeakCheckContext(t) tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") require.NoError(t, err, "couldn't create temp file: %v", err) - defer os.Remove(tmpFile.Name()) oldStr := "str1" @@ -147,6 +157,16 @@ func TestStaticConfigHUPWithRotation(t *testing.T) { hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4") hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5") + + mu.Lock() + defer mu.Unlock() + // Delete registered Auth servers. + for k, v := range authServers { + if s, ok := v.(*AuthServerStatic); ok { + s.close() + } + delete(authServers, k) + } } func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { @@ -178,6 +198,7 @@ func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.Fi } func TestStaticPasswords(t *testing.T) { + _ = utils.LeakCheckContext(t) jsonConfig := ` { "user01": [{ "Password": "user01" }], diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 284189c30e8..13ed1099e58 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -37,6 +37,7 @@ import ( // This file tests the handshake scenarios between our client and our server. func TestClearTextClientAuth(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -51,10 +52,6 @@ func TestClearTextClientAuth(t *testing.T) { defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -63,9 +60,10 @@ func TestClearTextClientAuth(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for clear text auth. - ctx := context.Background() _, err = Connect(ctx, params) if err == nil || !strings.Contains(err.Error(), "Cannot use clear text authentication over non-SSL connections") { t.Fatalf("unexpected connection error: %v", err) @@ -92,6 +90,7 @@ func TestClearTextClientAuth(t *testing.T) { // TestSSLConnection creates a server with TLS support, a client that // also has SSL support, and connects them. func TestSSLConnection(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -103,7 +102,6 @@ func TestSSLConnection(t *testing.T) { // Create the listener, so we can get its host. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -122,12 +120,6 @@ func TestSSLConnection(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -141,20 +133,22 @@ func TestSSLConnection(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer cleanupListener(ctx, l, params) t.Run("Basics", func(t *testing.T) { - testSSLConnectionBasics(t, params) + testSSLConnectionBasics(t, ctx, params) }) // Make sure clear text auth works over SSL. t.Run("ClearText", func(t *testing.T) { - testSSLConnectionClearText(t, params) + testSSLConnectionClearText(t, ctx, params) }) } -func testSSLConnectionClearText(t *testing.T, params *ConnParams) { +func testSSLConnectionClearText(t *testing.T, ctx context.Context, params *ConnParams) { // Create a client connection, connect. - ctx := context.Background() conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) @@ -170,9 +164,8 @@ func testSSLConnectionClearText(t *testing.T, params *ConnParams) { conn.writeComQuit() } -func testSSLConnectionBasics(t *testing.T, params *ConnParams) { +func testSSLConnectionBasics(t *testing.T, ctx context.Context, params *ConnParams) { // Create a client connection, connect. - ctx := context.Background() conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) diff --git a/go/mysql/replication_test.go b/go/mysql/replication_test.go index c397bc71b45..c9a54485497 100644 --- a/go/mysql/replication_test.go +++ b/go/mysql/replication_test.go @@ -23,10 +23,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/utils" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" ) func TestComBinlogDump(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close() @@ -72,6 +75,7 @@ func TestComBinlogDump(t *testing.T) { } func TestComBinlogDumpGTID(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close() @@ -161,6 +165,7 @@ func TestComBinlogDumpGTID(t *testing.T) { } func TestSendSemiSyncAck(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close() diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 88a9d3d67be..975bc964633 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -267,6 +267,7 @@ func getHostPort(t *testing.T, a net.Addr) (string, int) { } func TestConnectionFromListener(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -282,9 +283,6 @@ func TestConnectionFromListener(t *testing.T) { l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) fmt.Printf("host: %s, port: %d\n", host, port) // Setup the right parameters. @@ -294,13 +292,16 @@ func TestConnectionFromListener(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestConnectionWithoutSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -309,13 +310,10 @@ func TestConnectionWithoutSourceHost(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -323,13 +321,16 @@ func TestConnectionWithoutSourceHost(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestConnectionWithSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -344,11 +345,7 @@ func TestConnectionWithSourceHost(t *testing.T) { l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -356,13 +353,16 @@ func TestConnectionWithSourceHost(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) - _, err = Connect(context.Background(), params) + _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") } func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -377,11 +377,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -389,13 +385,16 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { Uname: "user1", Pass: "mysql_password", } + go l.Accept() + defer cleanupListener(ctx, l, params) - _, err = Connect(context.Background(), params) + _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") } func TestConnectionUnixSocket(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -415,22 +414,22 @@ func TestConnectionUnixSocket(t *testing.T) { l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - // Setup the right parameters. params := &ConnParams{ UnixSocket: unixSocket.Name(), Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestClientFoundRows(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -439,13 +438,10 @@ func TestClientFoundRows(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -453,9 +449,11 @@ func TestClientFoundRows(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) // Test without flag. - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Connect failed") foundRows := th.LastConn().Capabilities & CapabilityClientFoundRows assert.Equal(t, uint32(0), foundRows, "FoundRows flag: %x, second bit must be 0", th.LastConn().Capabilities) @@ -464,7 +462,7 @@ func TestClientFoundRows(t *testing.T) { // Test with flag. params.Flags |= CapabilityClientFoundRows - c, err = Connect(context.Background(), params) + c, err = Connect(ctx, params) require.NoError(t, err, "Connect failed") foundRows = th.LastConn().Capabilities & CapabilityClientFoundRows assert.NotZero(t, foundRows, "FoundRows flag: %x, second bit must be set", th.LastConn().Capabilities) @@ -472,6 +470,7 @@ func TestClientFoundRows(t *testing.T) { } func TestConnCounts(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} user := "anotherNotYetConnectedUser1" @@ -483,13 +482,10 @@ func TestConnCounts(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Test with one new connection. params := &ConnParams{ Host: host, @@ -497,14 +493,16 @@ func TestConnCounts(t *testing.T) { Uname: user, Pass: passwd, } + go l.Accept() + defer cleanupListener(ctx, l, params) - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Connect failed") checkCountsForUser(t, user, 1) // Test with a second new connection. - c2, err := Connect(context.Background(), params) + c2, err := Connect(ctx, params) require.NoError(t, err) checkCountsForUser(t, user, 2) @@ -529,6 +527,7 @@ func checkCountsForUser(t assert.TestingT, user string, expected int64) { } func TestServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -537,14 +536,10 @@ func TestServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -552,6 +547,9 @@ func TestServer(t *testing.T) { Uname: "user1", Pass: "password1", } + l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) + go l.Accept() + defer cleanupListener(ctx, l, params) // Run a 'select rows' command with results. output, err := runMysqlWithErr(t, params, "select rows") @@ -629,6 +627,7 @@ func TestServer(t *testing.T) { } func TestServerStats(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -637,14 +636,10 @@ func TestServerStats(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -652,6 +647,9 @@ func TestServerStats(t *testing.T) { Uname: "user1", Pass: "password1", } + l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) + go l.Accept() + defer cleanupListener(ctx, l, params) timings.Reset() connCount.Reset() @@ -717,6 +715,7 @@ func TestServerStats(t *testing.T) { // TestClearTextServer creates a Server that needs clear text // passwords from the client. func TestClearTextServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -725,16 +724,10 @@ func TestClearTextServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - - version, _ := runMysql(t, nil, "--version") - isMariaDB := strings.Contains(version, "MariaDB") - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -742,6 +735,11 @@ func TestClearTextServer(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) + + version, _ := runMysql(t, nil, "--version") + isMariaDB := strings.Contains(version, "MariaDB") // Run a 'select rows' command with results. This should fail // as clear text is not enabled by default on the client @@ -790,6 +788,7 @@ func TestClearTextServer(t *testing.T) { // TestDialogServer creates a Server that uses the dialog plugin on the client. func TestDialogServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlDialog) @@ -798,14 +797,11 @@ func TestDialogServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) l.AllowClearTextWithoutTLS.Store(true) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -814,6 +810,9 @@ func TestDialogServer(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer cleanupListener(ctx, l, params) + sql := "select rows" output, ok := runMysql(t, params, sql) if strings.Contains(output, "No such file or directory") || strings.Contains(output, "Authentication plugin 'dialog' cannot be loaded") { @@ -829,6 +828,7 @@ func TestDialogServer(t *testing.T) { // TestTLSServer creates a Server with TLS support, then uses mysql // client to connect to it. func TestTLSServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -837,46 +837,20 @@ func TestTLSServer(t *testing.T) { }} defer authServer.close() + // Create the certs. + root := t.TempDir() + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - - // Create the certs. - root := t.TempDir() - tlstest.CreateCA(root) - tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") - tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") - - // Create the server with TLS config. - serverConfig, err := vttls.ServerConfig( - path.Join(root, "server-cert.pem"), - path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem"), - "", - "", - tls.VersionTLS12) - require.NoError(t, err) - l.TLSConfig.Store(serverConfig) - - var wg sync.WaitGroup - wg.Add(1) - go func(l *Listener) { - wg.Done() - l.Accept() - }(l) - // This is ensure the listener is called - wg.Wait() - // Sleep so that the Accept function is called as well.' - time.Sleep(3 * time.Second) - - connCountByTLSVer.ResetAll() // Setup the right parameters. params := &ConnParams{ Host: host, @@ -890,9 +864,23 @@ func TestTLSServer(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + "", + "", + tls.VersionTLS12) + require.NoError(t, err) + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer cleanupListener(ctx, l, params) + + connCountByTLSVer.ResetAll() // Run a 'select rows' command with results. - conn, err := Connect(context.Background(), params) + conn, err := Connect(ctx, params) // output, ok := runMysql(t, params, "select rows") require.NoError(t, err) results, err := conn.ExecuteFetch("select rows", 1000, true) @@ -927,25 +915,9 @@ func TestTLSServer(t *testing.T) { // TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql // client is rejected func TestTLSRequired(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) - authServer.entries["user1"] = []*AuthServerStaticEntry{{ - Password: "password1", - }} - defer authServer.close() - - // Create the listener, so we can get its host. - // Below, we are enabling --ssl-verify-server-cert, which adds - // a check that the common name of the certificate matches the - // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err) - defer l.Close() - - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) @@ -954,6 +926,13 @@ func TestTLSRequired(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert") tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client") + params := &ConnParams{ + Uname: "user1", + Pass: "password1", + SslMode: vttls.Disabled, // TLS is disabled at first + ServerName: "server.example.com", + } + // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -963,34 +942,49 @@ func TestTLSRequired(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err) - l.TLSConfig.Store(serverConfig) - l.RequireSecureTransport = true - - var wg sync.WaitGroup - wg.Add(1) - go func(l *Listener) { - wg.Done() - l.Accept() - }(l) - // This is ensure the listener is called - wg.Wait() - // Sleep so that the Accept function is called as well.' - time.Sleep(3 * time.Second) - - // Setup conn params without SSL. - params := &ConnParams{ - Host: host, - Port: port, - Uname: "user1", - Pass: "password1", - SslMode: vttls.Disabled, - ServerName: "server.example.com", + + authServer := NewAuthServerStatic("", "", 0) + authServer.entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + }} + defer authServer.close() + + var l *Listener + setupServer := func() { + // Create the listener, so we can get its host. + // Below, we are enabling --ssl-verify-server-cert, which adds + // a check that the common name of the certificate matches the + // server host name we connect to. + l, err = NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port + l.TLSConfig.Store(serverConfig) + l.RequireSecureTransport = true + go l.Accept() + params.Host = host + params.Port = port } - conn, err := Connect(context.Background(), params) - require.NotNil(t, err) - require.Contains(t, err.Error(), "Code: UNAVAILABLE") - require.Contains(t, err.Error(), "server does not allow insecure connections, client must use SSL/TLS") - require.Contains(t, err.Error(), "(errno 1105) (sqlstate HY000)") + setupServer() + + defer cleanupListener(ctx, l, params) + + // This test calls Connect multiple times so we add handling for when the + // listener goes away for any reason. + connectWithGoneServerHandling := func() (*Conn, error) { + conn, err := Connect(ctx, params) + if sqlErr, ok := sqlerror.NewSQLErrorFromError(err).(*sqlerror.SQLError); ok && sqlErr.Num == sqlerror.CRConnHostError { + cleanupListener(ctx, l, params) + setupServer() + conn, err = Connect(ctx, params) + } + return conn, err + } + + conn, err := connectWithGoneServerHandling() + require.ErrorContains(t, err, "Code: UNAVAILABLE") + require.ErrorContains(t, err, "server does not allow insecure connections, client must use SSL/TLS") + require.ErrorContains(t, err, "(errno 1105) (sqlstate HY000)") if conn != nil { conn.Close() } @@ -1001,7 +995,7 @@ func TestTLSRequired(t *testing.T) { params.SslCert = path.Join(root, "client-cert.pem") params.SslKey = path.Join(root, "client-key.pem") - conn, err = Connect(context.Background(), params) + conn, err = connectWithGoneServerHandling() require.NoError(t, err) if conn != nil { conn.Close() @@ -1010,15 +1004,15 @@ func TestTLSRequired(t *testing.T) { // setup conn params with TLS, but with a revoked client certificate params.SslCert = path.Join(root, "revoked-client-cert.pem") params.SslKey = path.Join(root, "revoked-client-key.pem") - conn, err = Connect(context.Background(), params) - require.NotNil(t, err) - require.Contains(t, err.Error(), "remote error: tls: bad certificate") + conn, err = connectWithGoneServerHandling() + require.ErrorContains(t, err, "remote error: tls: bad certificate") if conn != nil { conn.Close() } } func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) @@ -1027,19 +1021,17 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { } defer authServer.close() - // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err, "NewListener failed: %v", err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -1049,12 +1041,6 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1068,10 +1054,11 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() - conn, err := Connect(ctx, params) require.NoError(t, err, "unexpected connection error: %v", err) @@ -1107,12 +1094,11 @@ func newAuthServerAlwaysFallback(file, jsonConfig string, reloadInterval time.Du authMethod := NewSha2CachingAuthMethod(&alwaysFallbackAuth{}, a, a) a.methods = []AuthMethod{authMethod} - a.reload() - a.installSignalHandlers() return a } func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerAlwaysFallback("", "", 0) @@ -1121,19 +1107,17 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { } defer authServer.close() - // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err, "NewListener failed: %v", err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -1143,12 +1127,6 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1162,10 +1140,11 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() - conn, err := Connect(ctx, params) require.NoError(t, err, "unexpected connection error: %v", err) @@ -1182,6 +1161,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { } func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) @@ -1193,13 +1173,8 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { // Create the listener. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1208,9 +1183,10 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() _, err = Connect(ctx, params) if err == nil || !strings.Contains(err.Error(), "No authentication methods available for authentication") { t.Fatalf("unexpected connection error: %v", err) @@ -1225,6 +1201,7 @@ func checkCountForTLSVer(t *testing.T, version string, expected int64) { } func TestErrorCodes(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -1233,13 +1210,10 @@ func TestErrorCodes(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1247,8 +1221,9 @@ func TestErrorCodes(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) - ctx := context.Background() client, err := Connect(ctx, params) require.NoError(t, err) @@ -1404,6 +1379,7 @@ func binaryPath(root, binary string) (string, error) { } func TestListenerShutdown(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) authServer.entries["user1"] = []*AuthServerStaticEntry{{ @@ -1411,13 +1387,10 @@ func TestListenerShutdown(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1425,9 +1398,12 @@ func TestListenerShutdown(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer cleanupListener(ctx, l, params) + connRefuse.Reset() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() conn, err := Connect(ctx, params) @@ -1438,7 +1414,7 @@ func TestListenerShutdown(t *testing.T) { l.Shutdown() - waitForConnRefuse(t, 1) + waitForConnRefuse(t, ctx, 1) err = conn.Ping() require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") @@ -1450,8 +1426,8 @@ func TestListenerShutdown(t *testing.T) { require.Equal(t, "Server shutdown in progress", sqlErr.Message) } -func waitForConnRefuse(t *testing.T, valWanted int64) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +func waitForConnRefuse(t *testing.T, ctx context.Context, valWanted int64) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() tick := time.NewTicker(100 * time.Millisecond) defer tick.Stop() @@ -1497,21 +1473,21 @@ func TestParseConnAttrs(t *testing.T) { } func TestServerFlush(t *testing.T) { + ctx := utils.LeakCheckContext(t) mysqlServerFlushDelay := 10 * time.Millisecond th := &testHandler{} l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, mysqlServerFlushDelay) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) params := &ConnParams{ Host: host, Port: port, } + go l.Accept() + defer cleanupListener(ctx, l, params) - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err) defer c.Close() @@ -1545,20 +1521,21 @@ func TestServerFlush(t *testing.T) { } func TestTcpKeepAlive(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) params := &ConnParams{ Host: host, Port: port, } + go l.Accept() + defer cleanupListener(ctx, l, params) // on connect, the tcp method should be called. - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err) defer c.Close() require.True(t, th.lastConn.keepAliveOn, "tcp property method not called")