Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mysql: Ensure we set up the initial collation correctly #15115

Merged
merged 2 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestValidCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestNoCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down
6 changes: 3 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
var err error
queries, err = handler.SQLParser().SplitStatementToPieces(query)
queries, err = handler.Env().Parser().SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
return c.writeErrorPacketFromErrorAndLog(err)
Expand All @@ -1256,7 +1256,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
PrepareStmt: queries[0],
}

statement, err := handler.SQLParser().ParseStrictDDL(query)
statement, err := handler.Env().Parser().ParseStrictDDL(query)
if err != nil {
log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
if !c.writeErrorPacketFromErrorAndLog(err) {
Expand Down Expand Up @@ -1364,7 +1364,7 @@ func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) {
var queries []string
var err error
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = handler.SQLParser().SplitStatementToPieces(query)
queries, err = handler.Env().Parser().SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
return c.writeErrorPacketFromErrorAndLog(err)
Expand Down
6 changes: 3 additions & 3 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import (
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
)

func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
Expand Down Expand Up @@ -1141,8 +1141,8 @@ func (t testRun) WarningCount(c *Conn) uint16 {
return 0
}

func (t testRun) SQLParser() *sqlparser.Parser {
return sqlparser.NewTestParser()
func (t testRun) Env() *vtenv.Environment {
return vtenv.NewTestEnv()
}

var _ Handler = (*testRun)(nil)
12 changes: 6 additions & 6 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/config"
"vitess.io/vitess/go/mysql/replication"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
)

const appendEntry = -1
Expand Down Expand Up @@ -127,7 +127,7 @@ type DB struct {
lastErrorMu sync.Mutex
lastError error

parser *sqlparser.Parser
env *vtenv.Environment
}

// QueryHandler is the interface used by the DB to simulate executed queries
Expand Down Expand Up @@ -181,15 +181,15 @@ func New(t testing.TB) *DB {
queryPatternUserCallback: make(map[*regexp.Regexp]func(string)),
patternData: make(map[string]exprResult),
lastErrorMu: sync.Mutex{},
parser: sqlparser.NewTestParser(),
env: vtenv.NewTestEnv(),
}

db.Handler = db

authServer := mysql.NewAuthServerNone()

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0, fmt.Sprintf("%s-Vitess", config.DefaultMySQLVersion), 0)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -845,6 +845,6 @@ func (db *DB) GetQueryPatternResult(key string) (func(string), ExpectedResult, b
return nil, ExpectedResult{nil, nil}, false, nil
}

func (db *DB) SQLParser() *sqlparser.Parser {
return db.parser
func (db *DB) Env() *vtenv.Environment {
return db.env
}
7 changes: 5 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/vt/tlstest"
Expand All @@ -45,7 +46,7 @@ func TestClearTextClientAuth(t *testing.T) {
defer authServer.close()

// Create the listener.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -77,6 +78,7 @@ func TestClearTextClientAuth(t *testing.T) {

defer conn.Close()

assert.Equal(t, collations.ID(collations.CollationUtf8mb4ID), conn.CharacterSet)
// Run a 'select rows' command with results.
result, err := conn.ExecuteFetch("select rows", 10000, true)
require.NoError(t, err, "ExecuteFetch failed: %v", err)
Expand All @@ -99,7 +101,7 @@ func TestSSLConnection(t *testing.T) {
defer authServer.close()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -176,6 +178,7 @@ func testSSLConnectionBasics(t *testing.T, params *ConnParams) {

defer conn.Close()
assert.Equal(t, "user1", conn.User, "Invalid conn.User, got %v was expecting user1", conn.User)
assert.Equal(t, collations.ID(collations.CollationUtf8mb4ID), conn.CharacterSet)

// Run a 'select rows' command with results.
result, err := conn.ExecuteFetch("select rows", 10000, true)
Expand Down
22 changes: 7 additions & 15 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
"vitess.io/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -133,7 +133,7 @@

ComResetConnection(c *Conn)

SQLParser() *sqlparser.Parser
Env() *vtenv.Environment
}

// UnimplementedHandler implemnts all of the optional callbacks so as to satisy
Expand Down Expand Up @@ -233,9 +233,6 @@
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
mysqlServerVersion string,
truncateErrLen int,

) (*Listener, error) {
cfg := ListenerConfig{
Listener: l,
Expand All @@ -247,8 +244,6 @@
ConnBufferPooling: connBufferPooling,
ConnKeepAlivePeriod: keepAlivePeriod,
FlushDelay: flushDelay,
MySQLServerVersion: mysqlServerVersion,
TruncateErrLen: truncateErrLen,
}
return NewListenerWithConfig(cfg)
}
Expand All @@ -264,19 +259,17 @@
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
mysqlServerVersion string,
truncateErrLen int,
) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion, truncateErrLen)
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)

Check warning on line 269 in go/mysql/server.go

View check run for this annotation

Codecov / codecov/patch

go/mysql/server.go#L269

Added line #L269 was not covered by tests
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion, truncateErrLen)
return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
Expand All @@ -293,8 +286,6 @@
ConnBufferPooling bool
ConnKeepAlivePeriod time.Duration
FlushDelay time.Duration
MySQLServerVersion string
TruncateErrLen int
}

// NewListenerWithConfig creates new listener using provided config. There are
Expand All @@ -315,15 +306,16 @@
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: cfg.MySQLServerVersion,
ServerVersion: cfg.Handler.Env().MySQLVersion(),
Copy link
Contributor

@mattlord mattlord Feb 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Handler and/or Env() be nil here? IMO it's worth a safety check here that returns an error or even an explicit panic with a message if it really shouldn't happen and we can't proceed. Same for the cfg.Handler chains below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should never ever be nil under any circumstance. That it panics in that case is good. I don't think we should guard this, as it means many guards all over the place then (the env is used in many other places too).

connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
connKeepAlivePeriod: cfg.ConnKeepAlivePeriod,
flushDelay: cfg.FlushDelay,
truncateErrLen: cfg.TruncateErrLen,
truncateErrLen: cfg.Handler.Env().TruncateErrLen(),
charset: cfg.Handler.Env().CollationEnv().DefaultConnectionCharset(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}, nil
}

Expand Down
Loading
Loading