From 55cdd40a47af93bf1d4cc282c7f8c78e6734f574 Mon Sep 17 00:00:00 2001 From: Michael Demmer Date: Wed, 1 Nov 2023 17:37:06 -0700 Subject: [PATCH] basic skeleton of a working proxy --- go/vt/vtgateproxy/mysql_server.go | 200 +++++++++++++++--------------- go/vt/vtgateproxy/vtgateproxy.go | 66 ++++++++-- 2 files changed, 160 insertions(+), 106 deletions(-) diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index 4a6eb0132d5..856512add00 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -32,6 +32,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" @@ -43,10 +44,7 @@ import ( "vitess.io/vitess/go/vt/vttls" querypb "vitess.io/vitess/go/vt/proto/query" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - - "github.com/google/uuid" ) var ( @@ -87,47 +85,35 @@ type proxyHandler struct { mysql.UnimplementedHandler mu sync.Mutex - proxy *VTGateProxy - connections map[*mysql.Conn]bool + proxy *VTGateProxy } func newProxyHandler(proxy *VTGateProxy) *proxyHandler { return &proxyHandler{ - proxy: proxy, - connections: make(map[*mysql.Conn]bool), + proxy: proxy, } } -func (vh *proxyHandler) NewConnection(c *mysql.Conn) { - vh.mu.Lock() - defer vh.mu.Unlock() - vh.connections[c] = true -} - -func (vh *proxyHandler) numConnections() int { - vh.mu.Lock() - defer vh.mu.Unlock() - return len(vh.connections) +func (ph *proxyHandler) NewConnection(c *mysql.Conn) { } -func (vh *proxyHandler) ComResetConnection(c *mysql.Conn) { +func (ph *proxyHandler) ComResetConnection(c *mysql.Conn) { ctx := context.Background() - session := vh.session(c) - if session.InTransaction { + session := ph.session(c) + if session.SessionPb().InTransaction { defer atomic.AddInt32(&busyConnections, -1) } - err := vh.proxy.CloseSession(ctx, session) + err := ph.proxy.CloseSession(ctx, session) if err != nil { log.Errorf("Error happened in transaction rollback: %v", err) } } -func (vh *proxyHandler) ConnectionClosed(c *mysql.Conn) { +func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. defer func() { - vh.mu.Lock() - defer vh.mu.Unlock() - delete(vh.connections, c) + ph.mu.Lock() + defer ph.mu.Unlock() }() var ctx context.Context @@ -138,11 +124,11 @@ func (vh *proxyHandler) ConnectionClosed(c *mysql.Conn) { } else { ctx = context.Background() } - session := vh.session(c) - if session.InTransaction { + session := ph.session(c) + if session.SessionPb().InTransaction { defer atomic.AddInt32(&busyConnections, -1) } - _ = vh.proxy.CloseSession(ctx, session) + _ = ph.proxy.CloseSession(ctx, session) } // Regexp to extract parent span id over the sql query @@ -179,7 +165,7 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString) } -func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { +func (ph *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { ctx := context.Background() var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -207,21 +193,26 @@ func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) - if !session.InTransaction { + session := ph.session(c) + if session != nil && !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, 1) } defer func() { - if !session.InTransaction { + if session == nil || !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, -1) } }() - if session.Options.Workload == querypb.ExecuteOptions_OLAP { - err := vh.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) - return mysql.NewSQLErrorFromError(err) - } - session, result, err := vh.proxy.Execute(ctx, session, query, make(map[string]*querypb.BindVariable)) + /* + XXX/demmer figure out OLAP + + if session.Options.Workload == querypb.ExecuteOptions_OLAP { + err := ph.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) + return mysql.NewSQLErrorFromError(err) + } + */ + + result, err := ph.proxy.Execute(ctx, session, query, make(map[string]*querypb.BindVariable)) if err := mysql.NewSQLErrorFromError(err); err != nil { return err @@ -230,13 +221,13 @@ func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql return callback(result) } -func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) { - if session.InTransaction { +func fillInTxStatusFlags(c *mysql.Conn, session *vtgateconn.VTGateSession) { + if session.SessionPb().InTransaction { c.StatusFlags |= mysql.ServerStatusInTrans } else { c.StatusFlags &= mysql.NoServerStatusInTrans } - if session.Autocommit { + if session.SessionPb().Autocommit { c.StatusFlags |= mysql.ServerStatusAutocommit } else { c.StatusFlags &= mysql.NoServerStatusAutocommit @@ -244,7 +235,7 @@ func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) { } // ComPrepare is the handler for command prepare. -func (vh *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { +func (ph *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { var ctx context.Context var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -268,17 +259,17 @@ func (vh *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[str "VTGateProxy MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) - if !session.InTransaction { + session := ph.session(c) + if !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, 1) } defer func() { - if !session.InTransaction { + if !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, -1) } }() - session, fld, err := vh.proxy.Prepare(ctx, session, query, bindVars) + session, fld, err := ph.proxy.Prepare(ctx, session, query, bindVars) err = mysql.NewSQLErrorFromError(err) if err != nil { return nil, err @@ -286,7 +277,7 @@ func (vh *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[str return fld, nil } -func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { +func (ph *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { var ctx context.Context var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -310,21 +301,25 @@ func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData "VTGateProxy MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) - if !session.InTransaction { + session := ph.session(c) + if !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, 1) } defer func() { - if !session.InTransaction { + if !session.SessionPb().InTransaction { atomic.AddInt32(&busyConnections, -1) } }() - if session.Options.Workload == querypb.ExecuteOptions_OLAP { - err := vh.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) - return mysql.NewSQLErrorFromError(err) - } - _, qr, err := vh.proxy.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) + /* + XXX/demmer figure out OLAP + if session.Options.Workload == querypb.ExecuteOptions_OLAP { + err := ph.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) + return mysql.NewSQLErrorFromError(err) + } + */ + + qr, err := ph.proxy.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) if err != nil { err = mysql.NewSQLErrorFromError(err) return err @@ -334,43 +329,45 @@ func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData return callback(qr) } -func (vh *proxyHandler) WarningCount(c *mysql.Conn) uint16 { - return uint16(len(vh.session(c).GetWarnings())) +func (ph *proxyHandler) WarningCount(c *mysql.Conn) uint16 { + return uint16(len(ph.session(c).SessionPb().GetWarnings())) } // ComBinlogDumpGTID is part of the mysql.Handler interface. -func (vh *proxyHandler) ComBinlogDumpGTID(c *mysql.Conn, gtidSet mysql.GTIDSet) error { +func (ph *proxyHandler) ComBinlogDumpGTID(c *mysql.Conn, gtidSet mysql.GTIDSet) error { return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "ComBinlogDumpGTID") } -func (vh *proxyHandler) session(c *mysql.Conn) *vtgatepb.Session { - session, _ := c.ClientData.(*vtgatepb.Session) +func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession { + session, _ := c.ClientData.(*vtgateconn.VTGateSession) if session == nil { - u, _ := uuid.NewUUID() - session = &vtgatepb.Session{ - Options: &querypb.ExecuteOptions{ - IncludedFields: querypb.ExecuteOptions_ALL, - Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload), - - // The collation field of ExecuteOption is set right before an execution. - }, - Autocommit: true, - DDLStrategy: *defaultDDLStrategy, - SessionUUID: u.String(), - EnableSystemSettings: *sysVarSetEnabled, + options := &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload), } + if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { - session.Options.ClientFoundRows = true + options.ClientFoundRows = true + } + + var err error + session, err = ph.proxy.NewSession(options) + if err != nil { + log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err) + } + + if session != nil { + c.ClientData = session } - c.ClientData = session } + return session } var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener var sigChan chan os.Signal -var vtgateHandle *proxyHandler +var proxyHandle *proxyHandler // initTLSConfig inits tls config for the given mysql listener func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error { @@ -426,10 +423,10 @@ func initMySQLProtocol() { // Create a Listener. var err error - vtgateHandle = newProxyHandler(vtGateProxy) + proxyHandle = newProxyHandler(vtGateProxy) if *mysqlServerPort >= 0 { log.Infof("Mysql Server listening on Port %d", *mysqlServerPort) - mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, vtgateHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol) + mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) } @@ -458,7 +455,7 @@ func initMySQLProtocol() { // Let's create this unix socket with permissions to all users. In this way, // clients can connect to vtgate mysql server without being vtgate user oldMask := syscall.Umask(000) - mysqlUnixListener, err = newMysqlUnixSocket(*mysqlServerSocketPath, authServer, vtgateHandle) + mysqlUnixListener, err = newMysqlUnixSocket(*mysqlServerSocketPath, authServer, proxyHandle) _ = syscall.Umask(oldMask) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) @@ -531,30 +528,35 @@ func shutdownMysqlProtocolAndDrain() { func rollbackAtShutdown() { defer log.Flush() - // Close all open connections. If they're waiting for reads, this will cause - // them to error out, which will automatically rollback open transactions. - func() { - if vtgateHandle != nil { - vtgateHandle.mu.Lock() - defer vtgateHandle.mu.Unlock() - for c := range vtgateHandle.connections { - if c != nil { - log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID) - c.Close() + // XXX/demmer figure out numConnections and rollback + /* + + // Close all open connections. If they're waiting for reads, this will cause + // them to error out, which will automatically rollback open transactions. + func() { + if proxyHandle != nil { + proxyHandle.mu.Lock() + defer proxyHandle.mu.Unlock() + for c := range proxyHandle.connections { + if c != nil { + log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID) + c.Close() + } } } - } - }() - - // If vtgate is instead busy executing a query, the number of open conns - // will be non-zero. Give another second for those queries to finish. - for i := 0; i < 100; i++ { - if vtgateHandle.numConnections() == 0 { - log.Infof("All connections have been rolled back.") - return - } - time.Sleep(10 * time.Millisecond) - } + }() + + + // If vtgate is instead busy executing a query, the number of open conns + // will be non-zero. Give another second for those queries to finish. + for i := 0; i < 100; i++ { + if proxyHandle.numConnections() == 0 { + log.Infof("All connections have been rolled back.") + return + } + time.Sleep(10 * time.Millisecond) + } + */ log.Errorf("All connections did not go idle. Shutting down anyway.") } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 691bf447fab..c56fcc4e429 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -21,16 +21,24 @@ package vtgateproxy import ( "context" "flag" + "time" + "google.golang.org/grpc" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/grpcclient" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/vterrors" + _ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" ) var ( + target = flag.String("target", "", "vtgate host:port target used to dial the GRPC connection") + dialTimeout = flag.Duration("dial_timeout", 5*time.Second, "dialer timeout for the GRPC connection") + defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") sysVarSetEnabled = flag.Bool("enable_system_settings", true, "This will enable the system settings to be changed per session at the database connection level") @@ -38,12 +46,36 @@ var ( ) type VTGateProxy struct { + conn *vtgateconn.VTGateConn +} + +func (proxy *VTGateProxy) connect(ctx context.Context) error { + grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { + return append(opts, grpc.WithBlock()), nil + }) + + conn, err := vtgateconn.DialProtocol(ctx, "grpc", *target) + if err != nil { + return err + } + + proxy.conn = conn + return nil +} + +func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions) (*vtgateconn.VTGateSession, error) { + if proxy.conn == nil { + return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected") + } + + // XXX/demmer handle schemaName? + return proxy.conn.Session("", options), nil } // CloseSession closes the session, rolling back any implicit transactions. This has the // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. -func (proxy *VTGateProxy) CloseSession(ctx context.Context, session *vtgatepb.Session) error { +func (proxy *VTGateProxy) CloseSession(ctx context.Context, session *vtgateconn.VTGateSession) error { return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented") } @@ -53,16 +85,36 @@ func (proxy *VTGateProxy) ResolveTransaction(ctx context.Context, dtid string) e } // Prepare supports non-streaming prepare statement query with multi shards -func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, fld []*querypb.Field, err error) { +func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (newsession *vtgateconn.VTGateSession, fld []*querypb.Field, err error) { return nil, nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented") } -func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, qr *sqltypes.Result, err error) { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented") +func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (qr *sqltypes.Result, err error) { + log.Infof("Execute %s", sql) + + if proxy.conn == nil { + return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected") + } + + return session.Execute(ctx, sql, bindVariables) } -func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { +func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "not implemented") } -func Init() {} +func Init() error { + vtGateProxy = &VTGateProxy{} + + // XXX maybe add connect timeout? + ctx, cancel := context.WithTimeout(context.Background(), *dialTimeout) + defer cancel() + err := vtGateProxy.connect(ctx) + if err != nil { + log.Fatalf("error connecting to vtgate: %v", err) + return err + } + log.Infof("Connected to VTGate at %s", *target) + + return nil +}