Skip to content

Commit

Permalink
mysql: Ensure we set up the initial collation correctly
Browse files Browse the repository at this point in the history
The collation env refactor missed a bunch of setup for the MySQL
listener, which resulted in a missing initial collation on the
connection.

This fixes that and adds additional assertions as well.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Feb 2, 2024
1 parent 6e7645c commit d21ab02
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 79 deletions.
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestValidCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestNoCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err)
defer l.Close()

Expand Down
6 changes: 3 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
var err error
queries, err = handler.SQLParser().SplitStatementToPieces(query)
queries, err = handler.Env().Parser().SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
return c.writeErrorPacketFromErrorAndLog(err)
Expand All @@ -1256,7 +1256,7 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
PrepareStmt: queries[0],
}

statement, err := handler.SQLParser().ParseStrictDDL(query)
statement, err := handler.Env().Parser().ParseStrictDDL(query)
if err != nil {
log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
if !c.writeErrorPacketFromErrorAndLog(err) {
Expand Down Expand Up @@ -1364,7 +1364,7 @@ func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) {
var queries []string
var err error
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = handler.SQLParser().SplitStatementToPieces(query)
queries, err = handler.Env().Parser().SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
return c.writeErrorPacketFromErrorAndLog(err)
Expand Down
6 changes: 3 additions & 3 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import (
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
)

func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
Expand Down Expand Up @@ -1141,8 +1141,8 @@ func (t testRun) WarningCount(c *Conn) uint16 {
return 0
}

func (t testRun) SQLParser() *sqlparser.Parser {
return sqlparser.NewTestParser()
func (t testRun) Env() *vtenv.Environment {
return vtenv.NewTestEnv()
}

var _ Handler = (*testRun)(nil)
12 changes: 6 additions & 6 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/config"
"vitess.io/vitess/go/mysql/replication"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
)

const appendEntry = -1
Expand Down Expand Up @@ -127,7 +127,7 @@ type DB struct {
lastErrorMu sync.Mutex
lastError error

parser *sqlparser.Parser
env *vtenv.Environment
}

// QueryHandler is the interface used by the DB to simulate executed queries
Expand Down Expand Up @@ -181,15 +181,15 @@ func New(t testing.TB) *DB {
queryPatternUserCallback: make(map[*regexp.Regexp]func(string)),
patternData: make(map[string]exprResult),
lastErrorMu: sync.Mutex{},
parser: sqlparser.NewTestParser(),
env: vtenv.NewTestEnv(),
}

db.Handler = db

authServer := mysql.NewAuthServerNone()

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0, fmt.Sprintf("%s-Vitess", config.DefaultMySQLVersion), 0)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -845,6 +845,6 @@ func (db *DB) GetQueryPatternResult(key string) (func(string), ExpectedResult, b
return nil, ExpectedResult{nil, nil}, false, nil
}

func (db *DB) SQLParser() *sqlparser.Parser {
return db.parser
func (db *DB) Env() *vtenv.Environment {
return db.env
}
7 changes: 5 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/vt/tlstest"
Expand All @@ -45,7 +46,7 @@ func TestClearTextClientAuth(t *testing.T) {
defer authServer.close()

// Create the listener.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -77,6 +78,7 @@ func TestClearTextClientAuth(t *testing.T) {

defer conn.Close()

assert.Equal(t, collations.ID(collations.CollationUtf8mb4ID), conn.CharacterSet)
// Run a 'select rows' command with results.
result, err := conn.ExecuteFetch("select rows", 10000, true)
require.NoError(t, err, "ExecuteFetch failed: %v", err)
Expand All @@ -99,7 +101,7 @@ func TestSSLConnection(t *testing.T) {
defer authServer.close()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -176,6 +178,7 @@ func testSSLConnectionBasics(t *testing.T, params *ConnParams) {

defer conn.Close()
assert.Equal(t, "user1", conn.User, "Invalid conn.User, got %v was expecting user1", conn.User)
assert.Equal(t, collations.ID(collations.CollationUtf8mb4ID), conn.CharacterSet)

// Run a 'select rows' command with results.
result, err := conn.ExecuteFetch("select rows", 10000, true)
Expand Down
22 changes: 7 additions & 15 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
"vitess.io/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -133,7 +133,7 @@ type Handler interface {

ComResetConnection(c *Conn)

SQLParser() *sqlparser.Parser
Env() *vtenv.Environment
}

// UnimplementedHandler implemnts all of the optional callbacks so as to satisy
Expand Down Expand Up @@ -233,9 +233,6 @@ func NewFromListener(
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
mysqlServerVersion string,
truncateErrLen int,

) (*Listener, error) {
cfg := ListenerConfig{
Listener: l,
Expand All @@ -247,8 +244,6 @@ func NewFromListener(
ConnBufferPooling: connBufferPooling,
ConnKeepAlivePeriod: keepAlivePeriod,
FlushDelay: flushDelay,
MySQLServerVersion: mysqlServerVersion,
TruncateErrLen: truncateErrLen,
}
return NewListenerWithConfig(cfg)
}
Expand All @@ -264,19 +259,17 @@ func NewListener(
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
mysqlServerVersion string,
truncateErrLen int,
) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion, truncateErrLen)
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, mysqlServerVersion, truncateErrLen)
return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
Expand All @@ -293,8 +286,6 @@ type ListenerConfig struct {
ConnBufferPooling bool
ConnKeepAlivePeriod time.Duration
FlushDelay time.Duration
MySQLServerVersion string
TruncateErrLen int
}

// NewListenerWithConfig creates new listener using provided config. There are
Expand All @@ -315,15 +306,16 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) {
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: cfg.MySQLServerVersion,
ServerVersion: cfg.Handler.Env().MySQLVersion(),
connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
connKeepAlivePeriod: cfg.ConnKeepAlivePeriod,
flushDelay: cfg.FlushDelay,
truncateErrLen: cfg.TruncateErrLen,
truncateErrLen: cfg.Handler.Env().TruncateErrLen(),
charset: cfg.Handler.Env().CollationEnv().DefaultConnectionCharset(),
}, nil
}

Expand Down
Loading

0 comments on commit d21ab02

Please sign in to comment.