Skip to content

Commit

Permalink
mysql: Refactor out usage of servenv (#14732)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Dec 8, 2023
1 parent 27f1ac2 commit ab1ba2e
Show file tree
Hide file tree
Showing 23 changed files with 157 additions and 148 deletions.
6 changes: 5 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
"vitess.io/vitess/go/vt/vtgate"
)

var clientcertAuthMethod string

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert() })
Main.Flags().StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")

vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert(clientcertAuthMethod) })
}
13 changes: 12 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ package cli
// This plugin imports ldapauthserver to register the LDAP implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/ldapauthserver"
"vitess.io/vitess/go/vt/vtgate"
)

var (
ldapAuthConfigFile string
ldapAuthConfigString string
ldapAuthMethod string
)

func init() {
vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init() })
Main.Flags().StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.")
Main.Flags().StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.")
Main.Flags().StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")

vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init(ldapAuthConfigFile, ldapAuthConfigString, ldapAuthMethod) })
}
16 changes: 15 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@ package cli
// This plugin imports staticauthserver to register the flat-file implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vtgate"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerStatic() })
Main.Flags().StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
Main.Flags().StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
Main.Flags().DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")

vtgate.RegisterPluginInitializer(func() {
mysql.InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString, mysqlAuthServerStaticReloadInterval)
})
}
28 changes: 27 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,36 @@ package cli
// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql/vault"
"vitess.io/vitess/go/vt/vtgate"
)

var (
vaultAddr string
vaultTimeout time.Duration
vaultCACert string
vaultPath string
vaultCacheTTL time.Duration
vaultTokenFile string
vaultRoleID string
vaultRoleSecretIDFile string
vaultRoleMountPoint string
)

func init() {
vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() })
Main.Flags().StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server")
Main.Flags().DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations")
Main.Flags().StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate")
Main.Flags().StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds")
Main.Flags().DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server")
Main.Flags().StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable")
Main.Flags().StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable")
Main.Flags().StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable")
Main.Flags().StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable")

vtgate.RegisterPluginInitializer(func() {
vault.InitAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint)
})
}
1 change: 1 addition & 0 deletions go/flags/endtoend/vtcombo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ Flags:
--mysql_default_workload string Default session workload (OLTP, OLAP, DBA) (default "OLTP")
--mysql_port int mysql port (default 3306)
--mysql_server_bind_address string Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.
--mysql_server_flush_delay duration Delay after which buffered response will be flushed to the client. (default 100ms)
--mysql_server_port int If set, also listen for MySQL binary protocol connections on this port. (default -1)
--mysql_server_query_timeout duration mysql query timeout
--mysql_server_read_timeout duration connection read timeout
Expand Down
15 changes: 3 additions & 12 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,16 @@ import (
"github.com/spf13/pflag"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
)

var clientcertAuthMethod string

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
})
}

// AuthServerClientCert implements AuthServer which enforces client side certificates
type AuthServerClientCert struct {
methods []AuthMethod
Method AuthMethodDescription
}

// InitAuthServerClientCert is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate)
func InitAuthServerClientCert() {
func InitAuthServerClientCert(clientcertAuthMethod string) {
if pflag.CommandLine.Lookup("mysql_server_ssl_ca").Value.String() == "" {
log.Info("Not configuring AuthServerClientCert because mysql_server_ssl_ca is empty")
return
Expand All @@ -50,11 +41,11 @@ func InitAuthServerClientCert() {
log.Exitf("Invalid mysql_clientcert_auth_method value: only support mysql_clear_password or dialog")
}

ascc := newAuthServerClientCert()
ascc := newAuthServerClientCert(clientcertAuthMethod)
RegisterAuthServer("clientcert", ascc)
}

func newAuthServerClientCert() *AuthServerClientCert {
func newAuthServerClientCert(clientcertAuthMethod string) *AuthServerClientCert {
ascc := &AuthServerClientCert{
Method: AuthMethodDescription(clientcertAuthMethod),
}
Expand Down
14 changes: 4 additions & 10 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ import (

const clientCertUsername = "Client Cert"

func init() {
// These tests do not invoke the servenv.Parse codepaths, so this default
// does not get set by the OnParseFor hook.
clientcertAuthMethod = string(MysqlClearPassword)
}

func TestValidCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -111,10 +105,10 @@ func TestValidCert(t *testing.T) {
func TestNoCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
21 changes: 1 addition & 20 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,15 @@ import (
"syscall"
"time"

"github.com/spf13/pflag"

"vitess.io/vitess/go/mysql/sqlerror"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
mysqlServerFlushDelay = 100 * time.Millisecond
)

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
fs.StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
fs.DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")
fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.")
})
}

const (
localhostName = "localhost"
)
Expand Down Expand Up @@ -94,7 +75,7 @@ type AuthServerStaticEntry struct {
}

// InitAuthServerStatic Handles initializing the AuthServerStatic if necessary.
func InitAuthServerStatic() {
func InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString string, mysqlAuthServerStaticReloadInterval time.Duration) {
// Check parameters.
if mysqlAuthServerStaticFile == "" && mysqlAuthServerStaticString == "" {
// Not configured, nothing to do.
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
}

// Send the connection back, so the other side can close it.
c := newConn(conn)
c := newConn(conn, params.FlushDelay)
status <- connectResult{
c: c,
}
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down
14 changes: 11 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ import (
)

const (
DefaultFlushDelay = 100 * time.Millisecond

// connBufferSize is how much we buffer for reading and
// writing. It is also how much we allocate for ephemeral buffers.
connBufferSize = 16 * 1024
Expand Down Expand Up @@ -128,6 +130,7 @@ type Conn struct {

bufferedReader *bufio.Reader
flushTimer *time.Timer
flushDelay time.Duration
header [packetHeaderSize]byte

// Keep track of how and of the buffer we allocated for an
Expand Down Expand Up @@ -246,10 +249,14 @@ var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, co

// newConn is an internal method to create a Conn. Used by client and server
// side for common creation code.
func newConn(conn net.Conn) *Conn {
func newConn(conn net.Conn, flushDelay time.Duration) *Conn {
if flushDelay == 0 {
flushDelay = DefaultFlushDelay
}
return &Conn{
conn: conn,
bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
flushDelay: flushDelay,
}
}

Expand All @@ -274,6 +281,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn {
listener: listener,
PrepareData: make(map[uint32]*PrepareData),
keepAliveOn: enabledKeepAlive,
flushDelay: listener.flushDelay,
}

if listener.connReadBufferSize > 0 {
Expand Down Expand Up @@ -347,7 +355,7 @@ func (c *Conn) returnReader() {
// startFlushTimer must be called while holding lock on bufMu.
func (c *Conn) startFlushTimer() {
if c.flushTimer == nil {
c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() {
c.flushTimer = time.AfterFunc(c.flushDelay, func() {
c.bufMu.Lock()
defer c.bufMu.Unlock()

Expand All @@ -357,7 +365,7 @@ func (c *Conn) startFlushTimer() {
c.bufferedWriter.Flush()
})
} else {
c.flushTimer.Reset(mysqlServerFlushDelay)
c.flushTimer.Reset(c.flushDelay)
}
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/conn_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var _ net.Addr = (*mockAddress)(nil)

// GetTestConn returns a conn for testing purpose only.
func GetTestConn() *Conn {
return newConn(testConn{})
return newConn(testConn{}, DefaultFlushDelay)
}

// GetTestServerConn is only meant to be used for testing.
Expand Down
12 changes: 6 additions & 6 deletions go/mysql/conn_flaky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
require.Nil(t, serverErr, "Accept failed: %v", serverErr)

// Create a Conn on both sides.
cConn := newConn(clientConn)
sConn := newConn(serverConn)
cConn := newConn(clientConn, DefaultFlushDelay)
sConn := newConn(serverConn, DefaultFlushDelay)
sConn.PrepareData = map[uint32]*PrepareData{}

return listener, sConn, cConn
Expand Down Expand Up @@ -942,7 +942,7 @@ func TestConnectionErrorWhileWritingComQuery(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComQuery, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
errorString := make([]byte, 17000)
Expand All @@ -958,7 +958,7 @@ func TestConnectionErrorWhileWritingComStmtSendLongData(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtSendLongData, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -972,7 +972,7 @@ func TestConnectionErrorWhileWritingComPrepare(t *testing.T) {
writeToPass: []bool{false},
pos: -1,
queryPacket: []byte{0x01, 0x00, 0x00, 0x00, ComPrepare},
})
}, DefaultFlushDelay)
sConn.Capabilities = sConn.Capabilities | CapabilityClientMultiStatements
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -987,7 +987,7 @@ func TestConnectionErrorWhileWritingComStmtExecute(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtExecute, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
res := sConn.handleNextCommand(handler)
Expand Down
Loading

0 comments on commit ab1ba2e

Please sign in to comment.