Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mysql: Refactor out usage of servenv #14732

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
"vitess.io/vitess/go/vt/vtgate"
)

var clientcertAuthMethod string

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert() })
Main.Flags().StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")

vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert(clientcertAuthMethod) })
}
13 changes: 12 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ package cli
// This plugin imports ldapauthserver to register the LDAP implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/ldapauthserver"
"vitess.io/vitess/go/vt/vtgate"
)

var (
ldapAuthConfigFile string
ldapAuthConfigString string
ldapAuthMethod string
)

func init() {
vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init() })
Main.Flags().StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.")
Main.Flags().StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.")
Main.Flags().StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ajm188 Had to use Main.Flags() here since the on parse callbacks don't work as it's too late since cli.go inits first, but dunno if this is the right fix then? Or if there's another way?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that's what we're currently doing? https://github.com/vitessio/vitess/blob/main/go/cmd/vtgate/cli/cli.go#L187-L189 i'm not sure if that answers your question


vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init(ldapAuthConfigFile, ldapAuthConfigString, ldapAuthMethod) })
}
16 changes: 15 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@ package cli
// This plugin imports staticauthserver to register the flat-file implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vtgate"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerStatic() })
Main.Flags().StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
Main.Flags().StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
Main.Flags().DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")

vtgate.RegisterPluginInitializer(func() {
mysql.InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString, mysqlAuthServerStaticReloadInterval)
})
}
28 changes: 27 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,36 @@ package cli
// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql/vault"
"vitess.io/vitess/go/vt/vtgate"
)

var (
vaultAddr string
vaultTimeout time.Duration
vaultCACert string
vaultPath string
vaultCacheTTL time.Duration
vaultTokenFile string
vaultRoleID string
vaultRoleSecretIDFile string
vaultRoleMountPoint string
)

func init() {
vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() })
Main.Flags().StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server")
Main.Flags().DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations")
Main.Flags().StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate")
Main.Flags().StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds")
Main.Flags().DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server")
Main.Flags().StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable")
Main.Flags().StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable")
Main.Flags().StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable")
Main.Flags().StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable")

vtgate.RegisterPluginInitializer(func() {
vault.InitAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint)
})
}
1 change: 1 addition & 0 deletions go/flags/endtoend/vtcombo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ Flags:
--mysql_default_workload string Default session workload (OLTP, OLAP, DBA) (default "OLTP")
--mysql_port int mysql port (default 3306)
--mysql_server_bind_address string Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.
--mysql_server_flush_delay duration Delay after which buffered response will be flushed to the client. (default 100ms)
--mysql_server_port int If set, also listen for MySQL binary protocol connections on this port. (default -1)
--mysql_server_query_timeout duration mysql query timeout
--mysql_server_read_timeout duration connection read timeout
Expand Down
15 changes: 3 additions & 12 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,16 @@ import (
"github.com/spf13/pflag"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
)

var clientcertAuthMethod string

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
})
}

// AuthServerClientCert implements AuthServer which enforces client side certificates
type AuthServerClientCert struct {
methods []AuthMethod
Method AuthMethodDescription
}

// InitAuthServerClientCert is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate)
func InitAuthServerClientCert() {
func InitAuthServerClientCert(clientcertAuthMethod string) {
if pflag.CommandLine.Lookup("mysql_server_ssl_ca").Value.String() == "" {
log.Info("Not configuring AuthServerClientCert because mysql_server_ssl_ca is empty")
return
Expand All @@ -50,11 +41,11 @@ func InitAuthServerClientCert() {
log.Exitf("Invalid mysql_clientcert_auth_method value: only support mysql_clear_password or dialog")
}

ascc := newAuthServerClientCert()
ascc := newAuthServerClientCert(clientcertAuthMethod)
RegisterAuthServer("clientcert", ascc)
}

func newAuthServerClientCert() *AuthServerClientCert {
func newAuthServerClientCert(clientcertAuthMethod string) *AuthServerClientCert {
ascc := &AuthServerClientCert{
Method: AuthMethodDescription(clientcertAuthMethod),
}
Expand Down
14 changes: 4 additions & 10 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ import (

const clientCertUsername = "Client Cert"

func init() {
// These tests do not invoke the servenv.Parse codepaths, so this default
// does not get set by the OnParseFor hook.
clientcertAuthMethod = string(MysqlClearPassword)
}

func TestValidCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -111,10 +105,10 @@ func TestValidCert(t *testing.T) {
func TestNoCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
21 changes: 1 addition & 20 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,15 @@ import (
"syscall"
"time"

"github.com/spf13/pflag"

"vitess.io/vitess/go/mysql/sqlerror"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
mysqlServerFlushDelay = 100 * time.Millisecond
)

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
fs.StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
fs.DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")
fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flag was in a very wrong place. It was in the auth plugin, but it's a MySQL protocol config flag! It was moved to the vtgate mysql plugin server instead which seems much more appropriate.

})
}

const (
localhostName = "localhost"
)
Expand Down Expand Up @@ -94,7 +75,7 @@ type AuthServerStaticEntry struct {
}

// InitAuthServerStatic Handles initializing the AuthServerStatic if necessary.
func InitAuthServerStatic() {
func InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString string, mysqlAuthServerStaticReloadInterval time.Duration) {
// Check parameters.
if mysqlAuthServerStaticFile == "" && mysqlAuthServerStaticString == "" {
// Not configured, nothing to do.
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
}

// Send the connection back, so the other side can close it.
c := newConn(conn)
c := newConn(conn, params.FlushDelay)
status <- connectResult{
c: c,
}
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down
14 changes: 11 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import (
)

const (
DefaultFlushDelay = 100 * time.Millisecond

// connBufferSize is how much we buffer for reading and
// writing. It is also how much we allocate for ephemeral buffers.
connBufferSize = 16 * 1024
Expand Down Expand Up @@ -129,6 +131,7 @@ type Conn struct {

bufferedReader *bufio.Reader
flushTimer *time.Timer
flushDelay time.Duration
header [packetHeaderSize]byte

// Keep track of how and of the buffer we allocated for an
Expand Down Expand Up @@ -247,10 +250,14 @@ var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, co

// newConn is an internal method to create a Conn. Used by client and server
// side for common creation code.
func newConn(conn net.Conn) *Conn {
func newConn(conn net.Conn, flushDelay time.Duration) *Conn {
if flushDelay == 0 {
dbussink marked this conversation as resolved.
Show resolved Hide resolved
flushDelay = DefaultFlushDelay
}
return &Conn{
conn: conn,
bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
flushDelay: flushDelay,
}
}

Expand All @@ -275,6 +282,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn {
listener: listener,
PrepareData: make(map[uint32]*PrepareData),
keepAliveOn: enabledKeepAlive,
flushDelay: listener.flushDelay,
}

if listener.connReadBufferSize > 0 {
Expand Down Expand Up @@ -348,7 +356,7 @@ func (c *Conn) returnReader() {
// startFlushTimer must be called while holding lock on bufMu.
func (c *Conn) startFlushTimer() {
if c.flushTimer == nil {
c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() {
c.flushTimer = time.AfterFunc(c.flushDelay, func() {
c.bufMu.Lock()
defer c.bufMu.Unlock()

Expand All @@ -358,7 +366,7 @@ func (c *Conn) startFlushTimer() {
c.bufferedWriter.Flush()
})
} else {
c.flushTimer.Reset(mysqlServerFlushDelay)
c.flushTimer.Reset(c.flushDelay)
}
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/conn_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var _ net.Addr = (*mockAddress)(nil)

// GetTestConn returns a conn for testing purpose only.
func GetTestConn() *Conn {
return newConn(testConn{})
return newConn(testConn{}, DefaultFlushDelay)
}

// GetTestServerConn is only meant to be used for testing.
Expand Down
12 changes: 6 additions & 6 deletions go/mysql/conn_flaky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
require.Nil(t, serverErr, "Accept failed: %v", serverErr)

// Create a Conn on both sides.
cConn := newConn(clientConn)
sConn := newConn(serverConn)
cConn := newConn(clientConn, DefaultFlushDelay)
sConn := newConn(serverConn, DefaultFlushDelay)
sConn.PrepareData = map[uint32]*PrepareData{}

return listener, sConn, cConn
Expand Down Expand Up @@ -942,7 +942,7 @@ func TestConnectionErrorWhileWritingComQuery(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComQuery, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
errorString := make([]byte, 17000)
Expand All @@ -958,7 +958,7 @@ func TestConnectionErrorWhileWritingComStmtSendLongData(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtSendLongData, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -972,7 +972,7 @@ func TestConnectionErrorWhileWritingComPrepare(t *testing.T) {
writeToPass: []bool{false},
pos: -1,
queryPacket: []byte{0x01, 0x00, 0x00, 0x00, ComPrepare},
})
}, DefaultFlushDelay)
sConn.Capabilities = sConn.Capabilities | CapabilityClientMultiStatements
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -987,7 +987,7 @@ func TestConnectionErrorWhileWritingComStmtExecute(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtExecute, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
res := sConn.handleNextCommand(handler)
Expand Down
Loading
Loading