diff --git a/go/cmd/vtgate/cli/plugin_auth_clientcert.go b/go/cmd/vtgate/cli/plugin_auth_clientcert.go index 1a1334e71ba..d486669847f 100644 --- a/go/cmd/vtgate/cli/plugin_auth_clientcert.go +++ b/go/cmd/vtgate/cli/plugin_auth_clientcert.go @@ -23,6 +23,10 @@ import ( "vitess.io/vitess/go/vt/vtgate" ) +var clientcertAuthMethod string + func init() { - vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert() }) + Main.Flags().StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") + + vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert(clientcertAuthMethod) }) } diff --git a/go/cmd/vtgate/cli/plugin_auth_ldap.go b/go/cmd/vtgate/cli/plugin_auth_ldap.go index 7aab7e9c7f4..f8312267504 100644 --- a/go/cmd/vtgate/cli/plugin_auth_ldap.go +++ b/go/cmd/vtgate/cli/plugin_auth_ldap.go @@ -19,10 +19,21 @@ package cli // This plugin imports ldapauthserver to register the LDAP implementation of AuthServer. import ( + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/ldapauthserver" "vitess.io/vitess/go/vt/vtgate" ) +var ( + ldapAuthConfigFile string + ldapAuthConfigString string + ldapAuthMethod string +) + func init() { - vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init() }) + Main.Flags().StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.") + Main.Flags().StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.") + Main.Flags().StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") + + vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init(ldapAuthConfigFile, ldapAuthConfigString, ldapAuthMethod) }) } diff --git a/go/cmd/vtgate/cli/plugin_auth_static.go b/go/cmd/vtgate/cli/plugin_auth_static.go index 76cdf8318ba..7ed0e7b8f61 100644 --- a/go/cmd/vtgate/cli/plugin_auth_static.go +++ b/go/cmd/vtgate/cli/plugin_auth_static.go @@ -19,10 +19,24 @@ package cli // This plugin imports staticauthserver to register the flat-file implementation of AuthServer. import ( + "time" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/vt/vtgate" ) +var ( + mysqlAuthServerStaticFile string + mysqlAuthServerStaticString string + mysqlAuthServerStaticReloadInterval time.Duration +) + func init() { - vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerStatic() }) + Main.Flags().StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.") + Main.Flags().StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.") + Main.Flags().DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials") + + vtgate.RegisterPluginInitializer(func() { + mysql.InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString, mysqlAuthServerStaticReloadInterval) + }) } diff --git a/go/cmd/vtgate/cli/plugin_auth_vault.go b/go/cmd/vtgate/cli/plugin_auth_vault.go index fe5fe2207d4..a119d2d389b 100644 --- a/go/cmd/vtgate/cli/plugin_auth_vault.go +++ b/go/cmd/vtgate/cli/plugin_auth_vault.go @@ -19,10 +19,36 @@ package cli // This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer. import ( + "time" + "vitess.io/vitess/go/mysql/vault" "vitess.io/vitess/go/vt/vtgate" ) +var ( + vaultAddr string + vaultTimeout time.Duration + vaultCACert string + vaultPath string + vaultCacheTTL time.Duration + vaultTokenFile string + vaultRoleID string + vaultRoleSecretIDFile string + vaultRoleMountPoint string +) + func init() { - vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() }) + Main.Flags().StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server") + Main.Flags().DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations") + Main.Flags().StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate") + Main.Flags().StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds") + Main.Flags().DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server") + Main.Flags().StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable") + Main.Flags().StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable") + Main.Flags().StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable") + Main.Flags().StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable") + + vtgate.RegisterPluginInitializer(func() { + vault.InitAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint) + }) } diff --git a/go/flags/endtoend/vtcombo.txt b/go/flags/endtoend/vtcombo.txt index d32df437787..cc3b55ee9cd 100644 --- a/go/flags/endtoend/vtcombo.txt +++ b/go/flags/endtoend/vtcombo.txt @@ -228,6 +228,7 @@ Flags: --mysql_default_workload string Default session workload (OLTP, OLAP, DBA) (default "OLTP") --mysql_port int mysql port (default 3306) --mysql_server_bind_address string Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance. + --mysql_server_flush_delay duration Delay after which buffered response will be flushed to the client. (default 100ms) --mysql_server_port int If set, also listen for MySQL binary protocol connections on this port. (default -1) --mysql_server_query_timeout duration mysql query timeout --mysql_server_read_timeout duration connection read timeout diff --git a/go/mysql/auth_server_clientcert.go b/go/mysql/auth_server_clientcert.go index 10a01487208..bb0a4028683 100644 --- a/go/mysql/auth_server_clientcert.go +++ b/go/mysql/auth_server_clientcert.go @@ -23,17 +23,8 @@ import ( "github.com/spf13/pflag" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/servenv" ) -var clientcertAuthMethod string - -func init() { - servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { - fs.StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") - }) -} - // AuthServerClientCert implements AuthServer which enforces client side certificates type AuthServerClientCert struct { methods []AuthMethod @@ -41,7 +32,7 @@ type AuthServerClientCert struct { } // InitAuthServerClientCert is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate) -func InitAuthServerClientCert() { +func InitAuthServerClientCert(clientcertAuthMethod string) { if pflag.CommandLine.Lookup("mysql_server_ssl_ca").Value.String() == "" { log.Info("Not configuring AuthServerClientCert because mysql_server_ssl_ca is empty") return @@ -50,11 +41,11 @@ func InitAuthServerClientCert() { log.Exitf("Invalid mysql_clientcert_auth_method value: only support mysql_clear_password or dialog") } - ascc := newAuthServerClientCert() + ascc := newAuthServerClientCert(clientcertAuthMethod) RegisterAuthServer("clientcert", ascc) } -func newAuthServerClientCert() *AuthServerClientCert { +func newAuthServerClientCert(clientcertAuthMethod string) *AuthServerClientCert { ascc := &AuthServerClientCert{ Method: AuthMethodDescription(clientcertAuthMethod), } diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 28ed19fd9c5..3314116e953 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -33,19 +33,13 @@ import ( const clientCertUsername = "Client Cert" -func init() { - // These tests do not invoke the servenv.Parse codepaths, so this default - // does not get set by the OnParseFor hook. - clientcertAuthMethod = string(MysqlClearPassword) -} - func TestValidCert(t *testing.T) { th := &testHandler{} - authServer := newAuthServerClientCert() + authServer := newAuthServerClientCert(string(MysqlClearPassword)) // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -111,10 +105,10 @@ func TestValidCert(t *testing.T) { func TestNoCert(t *testing.T) { th := &testHandler{} - authServer := newAuthServerClientCert() + authServer := newAuthServerClientCert(string(MysqlClearPassword)) // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index fae886039f0..6e3a9693c69 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -27,34 +27,15 @@ import ( "syscall" "time" - "github.com/spf13/pflag" - "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" ) -var ( - mysqlAuthServerStaticFile string - mysqlAuthServerStaticString string - mysqlAuthServerStaticReloadInterval time.Duration - mysqlServerFlushDelay = 100 * time.Millisecond -) - -func init() { - servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { - fs.StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.") - fs.StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.") - fs.DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials") - fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.") - }) -} - const ( localhostName = "localhost" ) @@ -94,7 +75,7 @@ type AuthServerStaticEntry struct { } // InitAuthServerStatic Handles initializing the AuthServerStatic if necessary. -func InitAuthServerStatic() { +func InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString string, mysqlAuthServerStaticReloadInterval time.Duration) { // Check parameters. if mysqlAuthServerStaticFile == "" && mysqlAuthServerStaticString == "" { // Not configured, nothing to do. diff --git a/go/mysql/client.go b/go/mysql/client.go index c4dd87d95cc..db1fd0cb68f 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -106,7 +106,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) { } // Send the connection back, so the other side can close it. - c := newConn(conn) + c := newConn(conn, params.FlushDelay) status <- connectResult{ c: c, } diff --git a/go/mysql/client_test.go b/go/mysql/client_test.go index c349cdcd531..057a8584679 100644 --- a/go/mysql/client_test.go +++ b/go/mysql/client_test.go @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 1908875db49..85a8ffd4027 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -43,6 +43,8 @@ import ( ) const ( + DefaultFlushDelay = 100 * time.Millisecond + // connBufferSize is how much we buffer for reading and // writing. It is also how much we allocate for ephemeral buffers. connBufferSize = 16 * 1024 @@ -128,6 +130,7 @@ type Conn struct { bufferedReader *bufio.Reader flushTimer *time.Timer + flushDelay time.Duration header [packetHeaderSize]byte // Keep track of how and of the buffer we allocated for an @@ -246,10 +249,14 @@ var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, co // newConn is an internal method to create a Conn. Used by client and server // side for common creation code. -func newConn(conn net.Conn) *Conn { +func newConn(conn net.Conn, flushDelay time.Duration) *Conn { + if flushDelay == 0 { + flushDelay = DefaultFlushDelay + } return &Conn{ conn: conn, bufferedReader: bufio.NewReaderSize(conn, connBufferSize), + flushDelay: flushDelay, } } @@ -274,6 +281,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { listener: listener, PrepareData: make(map[uint32]*PrepareData), keepAliveOn: enabledKeepAlive, + flushDelay: listener.flushDelay, } if listener.connReadBufferSize > 0 { @@ -347,7 +355,7 @@ func (c *Conn) returnReader() { // startFlushTimer must be called while holding lock on bufMu. func (c *Conn) startFlushTimer() { if c.flushTimer == nil { - c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() { + c.flushTimer = time.AfterFunc(c.flushDelay, func() { c.bufMu.Lock() defer c.bufMu.Unlock() @@ -357,7 +365,7 @@ func (c *Conn) startFlushTimer() { c.bufferedWriter.Flush() }) } else { - c.flushTimer.Reset(mysqlServerFlushDelay) + c.flushTimer.Reset(c.flushDelay) } } diff --git a/go/mysql/conn_fake.go b/go/mysql/conn_fake.go index e61f90d33f1..c20d09a2f6d 100644 --- a/go/mysql/conn_fake.go +++ b/go/mysql/conn_fake.go @@ -84,7 +84,7 @@ var _ net.Addr = (*mockAddress)(nil) // GetTestConn returns a conn for testing purpose only. func GetTestConn() *Conn { - return newConn(testConn{}) + return newConn(testConn{}, DefaultFlushDelay) } // GetTestServerConn is only meant to be used for testing. diff --git a/go/mysql/conn_flaky_test.go b/go/mysql/conn_flaky_test.go index 9df52a47589..0057aff5aa6 100644 --- a/go/mysql/conn_flaky_test.go +++ b/go/mysql/conn_flaky_test.go @@ -77,8 +77,8 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { require.Nil(t, serverErr, "Accept failed: %v", serverErr) // Create a Conn on both sides. - cConn := newConn(clientConn) - sConn := newConn(serverConn) + cConn := newConn(clientConn, DefaultFlushDelay) + sConn := newConn(serverConn, DefaultFlushDelay) sConn.PrepareData = map[uint32]*PrepareData{} return listener, sConn, cConn @@ -942,7 +942,7 @@ func TestConnectionErrorWhileWritingComQuery(t *testing.T) { pos: -1, queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComQuery, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31}, - }) + }, DefaultFlushDelay) // this handler will return an error on the first run, and fail the test if it's run more times errorString := make([]byte, 17000) @@ -958,7 +958,7 @@ func TestConnectionErrorWhileWritingComStmtSendLongData(t *testing.T) { pos: -1, queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtSendLongData, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31}, - }) + }, DefaultFlushDelay) // this handler will return an error on the first run, and fail the test if it's run more times handler := &testRun{t: t, err: fmt.Errorf("not used")} @@ -972,7 +972,7 @@ func TestConnectionErrorWhileWritingComPrepare(t *testing.T) { writeToPass: []bool{false}, pos: -1, queryPacket: []byte{0x01, 0x00, 0x00, 0x00, ComPrepare}, - }) + }, DefaultFlushDelay) sConn.Capabilities = sConn.Capabilities | CapabilityClientMultiStatements // this handler will return an error on the first run, and fail the test if it's run more times handler := &testRun{t: t, err: fmt.Errorf("not used")} @@ -987,7 +987,7 @@ func TestConnectionErrorWhileWritingComStmtExecute(t *testing.T) { pos: -1, queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtExecute, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31}, - }) + }, DefaultFlushDelay) // this handler will return an error on the first run, and fail the test if it's run more times handler := &testRun{t: t, err: fmt.Errorf("not used")} res := sConn.handleNextCommand(handler) diff --git a/go/mysql/conn_params.go b/go/mysql/conn_params.go index 061aa23f220..83b2dc78304 100644 --- a/go/mysql/conn_params.go +++ b/go/mysql/conn_params.go @@ -17,6 +17,8 @@ limitations under the License. package mysql import ( + "time" + "vitess.io/vitess/go/vt/vttls" ) @@ -57,6 +59,9 @@ type ConnParams struct { // for informative purposes. It has no programmatic value. Returning this field is // disabled by default. EnableQueryInfo bool + + // FlushDelay is the delay after which buffered response will be flushed to the client. + FlushDelay time.Duration } // EnableSSL will set the right flag on the parameters. diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index cb3d20ae04b..bd5435e6988 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -189,7 +189,7 @@ func New(t testing.TB) *DB { authServer := mysql.NewAuthServerNone() // Start listening. - db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0) + db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0, "8.0.30-Vitess") if err != nil { t.Fatalf("NewListener failed: %v", err) } diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index c2b27d6f6d4..57ed604daae 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -45,7 +45,7 @@ func TestClearTextClientAuth(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -99,7 +99,7 @@ func TestSSLConnection(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() diff --git a/go/mysql/ldapauthserver/auth_server_ldap.go b/go/mysql/ldapauthserver/auth_server_ldap.go index d5fcea027ac..5e6010fac0e 100644 --- a/go/mysql/ldapauthserver/auth_server_ldap.go +++ b/go/mysql/ldapauthserver/auth_server_ldap.go @@ -24,32 +24,16 @@ import ( "sync" "time" - "github.com/spf13/pflag" - ldap "gopkg.in/ldap.v2" + "gopkg.in/ldap.v2" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttls" querypb "vitess.io/vitess/go/vt/proto/query" ) -var ( - ldapAuthConfigFile string - ldapAuthConfigString string - ldapAuthMethod string -) - -func init() { - servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { - fs.StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.") - fs.StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.") - fs.StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") - }) -} - // AuthServerLdap implements AuthServer with an LDAP backend type AuthServerLdap struct { Client @@ -63,7 +47,7 @@ type AuthServerLdap struct { } // Init is public so it can be called from plugin_auth_ldap.go (go/cmd/vtgate) -func Init() { +func Init(ldapAuthConfigFile, ldapAuthConfigString, ldapAuthMethod string) { if ldapAuthConfigFile == "" && ldapAuthConfigString == "" { log.Infof("Not configuring AuthServerLdap because mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are empty") return diff --git a/go/mysql/mysql_fuzzer.go b/go/mysql/mysql_fuzzer.go index 2a3e797a797..057f2ac01c3 100644 --- a/go/mysql/mysql_fuzzer.go +++ b/go/mysql/mysql_fuzzer.go @@ -76,8 +76,8 @@ func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) { } // Create a Conn on both sides. - cConn := newConn(clientConn) - sConn := newConn(serverConn) + cConn := newConn(clientConn, DefaultFlushDelay) + sConn := newConn(serverConn, DefaultFlushDelay) return listener, sConn, cConn } @@ -196,7 +196,7 @@ func FuzzHandleNextCommand(data []byte) int { writeToPass: []bool{false}, pos: -1, queryPacket: data, - }) + }, DefaultFlushDelay) sConn.PrepareData = map[uint32]*PrepareData{} handler := &fuzztestRun{} @@ -327,7 +327,7 @@ func FuzzTLSServer(data []byte) int { Password: "password1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, 0, "8.0.30-Vitess") if err != nil { return -1 } diff --git a/go/mysql/server.go b/go/mysql/server.go index ec2d7538daa..1e321bae9d4 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -39,7 +39,6 @@ import ( "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" ) @@ -212,6 +211,9 @@ type Listener struct { // handled further by the MySQL handler. An non-nil error will stop // processing the connection by the MySQL handler. PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error) + + // flushDelay is the delay after which buffered response will be flushed to the client. + flushDelay time.Duration } // NewFromListener creates a new mysql listener from an existing net.Listener @@ -223,6 +225,8 @@ func NewFromListener( connWriteTimeout time.Duration, connBufferPooling bool, keepAlivePeriod time.Duration, + flushDelay time.Duration, + mysqlServerVersion string, ) (*Listener, error) { cfg := ListenerConfig{ Listener: l, @@ -233,6 +237,8 @@ func NewFromListener( ConnReadBufferSize: connBufferSize, ConnBufferPooling: connBufferPooling, ConnKeepAlivePeriod: keepAlivePeriod, + FlushDelay: flushDelay, + MySQLServerVersion: mysqlServerVersion, } return NewListenerWithConfig(cfg) } @@ -247,6 +253,8 @@ func NewListener( proxyProtocol bool, connBufferPooling bool, keepAlivePeriod time.Duration, + flushDelay time.Duration, + mysqlServerVersion string, ) (*Listener, error) { listener, err := net.Listen(protocol, address) if err != nil { @@ -254,10 +262,10 @@ func NewListener( } if proxyProtocol { proxyListener := &proxyproto.Listener{Listener: listener} - return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod) + return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion) } - return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod) + return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion) } // ListenerConfig should be used with NewListenerWithConfig to specify listener parameters. @@ -273,6 +281,8 @@ type ListenerConfig struct { ConnReadBufferSize int ConnBufferPooling bool ConnKeepAlivePeriod time.Duration + FlushDelay time.Duration + MySQLServerVersion string } // NewListenerWithConfig creates new listener using provided config. There are @@ -293,13 +303,14 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { authServer: cfg.AuthServer, handler: cfg.Handler, listener: l, - ServerVersion: servenv.AppVersion.MySQLVersion(), + ServerVersion: cfg.MySQLServerVersion, connectionID: 1, connReadTimeout: cfg.ConnReadTimeout, connWriteTimeout: cfg.ConnWriteTimeout, connReadBufferSize: cfg.ConnReadBufferSize, connBufferPooling: cfg.ConnBufferPooling, connKeepAlivePeriod: cfg.ConnKeepAlivePeriod, + flushDelay: cfg.FlushDelay, }, nil } diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 509fccaa47a..e68eab37e9a 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -263,6 +263,8 @@ func getHostPort(t *testing.T, a net.Addr) (string, int) { return host, port } +const mysqlVersion = "8.0.30-Vitess" + func TestConnectionFromListener(t *testing.T) { th := &testHandler{} @@ -277,7 +279,7 @@ func TestConnectionFromListener(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err, "net.Listener failed") - l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0) + l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -306,7 +308,7 @@ func TestConnectionWithoutSourceHost(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -339,7 +341,7 @@ func TestConnectionWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -372,7 +374,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -410,7 +412,7 @@ func TestConnectionUnixSocket(t *testing.T) { os.Remove(unixSocket.Name()) - l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0) + l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -436,7 +438,7 @@ func TestClientFoundRows(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -485,7 +487,7 @@ func TestConnCounts(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -542,7 +544,7 @@ func TestServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) defer l.Close() @@ -642,7 +644,7 @@ func TestServerStats(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) defer l.Close() @@ -716,7 +718,7 @@ func TestClearTextServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() go l.Accept() @@ -789,7 +791,7 @@ func TestDialogServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) l.AllowClearTextWithoutTLS.Store(true) defer l.Close() @@ -832,7 +834,7 @@ func TestTLSServer(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -930,7 +932,7 @@ func TestTLSRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() @@ -1019,7 +1021,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1113,7 +1115,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1182,7 +1184,7 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1224,7 +1226,7 @@ func TestErrorCodes(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() go l.Accept() @@ -1402,7 +1404,7 @@ func TestListenerShutdown(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() go l.Accept() @@ -1470,12 +1472,10 @@ func TestParseConnAttrs(t *testing.T) { } func TestServerFlush(t *testing.T) { - defer func(saved time.Duration) { mysqlServerFlushDelay = saved }(mysqlServerFlushDelay) - mysqlServerFlushDelay = 10 * time.Millisecond - + mysqlServerFlushDelay := 10 * time.Millisecond th := &testHandler{} - l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, mysqlServerFlushDelay, mysqlVersion) require.NoError(t, err) defer l.Close() go l.Accept() @@ -1521,7 +1521,7 @@ func TestServerFlush(t *testing.T) { func TestTcpKeepAlive(t *testing.T) { th := &testHandler{} - l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, 0, mysqlVersion) require.NoError(t, err) defer l.Close() go l.Accept() diff --git a/go/mysql/vault/auth_server_vault.go b/go/mysql/vault/auth_server_vault.go index ccdef9f1d53..d2bc2548817 100644 --- a/go/mysql/vault/auth_server_vault.go +++ b/go/mysql/vault/auth_server_vault.go @@ -28,41 +28,12 @@ import ( "time" vaultapi "github.com/aquarapid/vaultlib" - "github.com/spf13/pflag" - - "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/servenv" ) -var ( - vaultAddr string - vaultTimeout time.Duration - vaultCACert string - vaultPath string - vaultCacheTTL time.Duration - vaultTokenFile string - vaultRoleID string - vaultRoleSecretIDFile string - vaultRoleMountPoint string -) - -func init() { - servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { - fs.StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server") - fs.DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations") - fs.StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate") - fs.StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds") - fs.DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server") - fs.StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable") - fs.StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable") - fs.StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable") - fs.StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable") - }) -} - // AuthServerVault implements AuthServer with a config loaded from Vault. type AuthServerVault struct { methods []mysql.AuthMethod @@ -80,7 +51,7 @@ type AuthServerVault struct { } // InitAuthServerVault - entrypoint for initialization of Vault AuthServer implementation -func InitAuthServerVault() { +func InitAuthServerVault(vaultAddr string, vaultTimeout time.Duration, vaultCACert, vaultPath string, vaultCacheTTL time.Duration, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint string) { // Check critical parameters. if vaultAddr == "" { log.Infof("Not configuring AuthServerVault, as --mysql_auth_vault_addr is empty.") diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 00eb5e1b605..c3a67b1d7e1 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -74,6 +74,8 @@ var ( mysqlDefaultWorkloadName = "OLTP" mysqlDefaultWorkload int32 + + mysqlServerFlushDelay = 100 * time.Millisecond ) func registerPluginFlags(fs *pflag.FlagSet) { @@ -97,6 +99,7 @@ func registerPluginFlags(fs *pflag.FlagSet) { fs.DurationVar(&mysqlQueryTimeout, "mysql_server_query_timeout", mysqlQueryTimeout, "mysql query timeout") fs.BoolVar(&mysqlConnBufferPooling, "mysql-server-pool-conn-read-buffers", mysqlConnBufferPooling, "If set, the server will pool incoming connection read buffers") fs.DurationVar(&mysqlKeepAlivePeriod, "mysql-server-keepalive-period", mysqlKeepAlivePeriod, "TCP period between keep-alives") + fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.") fs.StringVar(&mysqlDefaultWorkloadName, "mysql_default_workload", mysqlDefaultWorkloadName, "Default session workload (OLTP, OLAP, DBA)") } @@ -526,11 +529,12 @@ func initMySQLProtocol(vtgate *VTGate) *mysqlServer { mysqlProxyProtocol, mysqlConnBufferPooling, mysqlKeepAlivePeriod, + mysqlServerFlushDelay, + servenv.MySQLServerVersion(), ) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) } - srv.tcpListener.ServerVersion = servenv.MySQLServerVersion() if mysqlSslCert != "" && mysqlSslKey != "" { tlsVersion, err := vttls.TLSVersionToNumber(mysqlTLSMinVersion) if err != nil { @@ -571,6 +575,8 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys false, mysqlConnBufferPooling, mysqlKeepAlivePeriod, + mysqlServerFlushDelay, + servenv.MySQLServerVersion(), ) switch err := err.(type) { @@ -603,6 +609,8 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys false, mysqlConnBufferPooling, mysqlKeepAlivePeriod, + mysqlServerFlushDelay, + servenv.MySQLServerVersion(), ) return listener, listenerErr default: diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 1aa201b5d4c..21375050a4d 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -348,7 +348,7 @@ func TestGracefulShutdown(t *testing.T) { vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected}) th := &testHandler{} - listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0) + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0, "8.0.30-Vitess") require.NoError(t, err) defer listener.Close() @@ -378,7 +378,7 @@ func TestGracefulShutdownWithTransaction(t *testing.T) { vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected}) th := &testHandler{} - listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0) + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0, "8.0.30-Vitess") require.NoError(t, err) defer listener.Close()