diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 0eab81e950c..d17d515def6 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -129,6 +129,11 @@ type Conn struct { // It is set during the initial handshake. UserData Getter + // ConnectionAttributes stores attributes set in the connection phase when + // attributes from the client are sent. This is arbitrary key/value pairs + // sent by the client. + ConnectionAttributes ConnectionAttributesMap + bufferedReader *bufio.Reader flushTimer *time.Timer header [packetHeaderSize]byte diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 415af39e761..1474f47d8e1 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -38,6 +38,10 @@ const ( // implemented authentication methods. type AuthMethodDescription string +// Map of client key/value pairs sent by the client during +// the connection phase +type ConnectionAttributesMap map[string]string + // Supported auth forms. const ( // MysqlNativePassword uses a salt and transmits a hash on the wire. diff --git a/go/mysql/server.go b/go/mysql/server.go index f59598b90f4..7d678d01bdb 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -354,11 +354,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } return } - user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response) + user, clientAuthMethod, clientAuthResponse, clientAttributes, err := l.parseClientHandshakePacket(c, true, response) if err != nil { log.Errorf("Cannot parse client handshake response from %s: %v", c, err) return } + c.ConnectionAttributes = clientAttributes c.recycleReadPacket() @@ -371,11 +372,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } // Returns copies of the data, so we can recycle the buffer. - user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response) + user, clientAuthMethod, clientAuthResponse, clientAttributes, err = l.parseClientHandshakePacket(c, false, response) if err != nil { log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err) return } + c.ConnectionAttributes = clientAttributes c.recycleReadPacket() if con, ok := c.conn.(*tls.Conn); ok { @@ -636,18 +638,18 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en } // parseClientHandshakePacket parses the handshake sent by the client. -// Returns the username, auth method, auth data, error. +// Returns the username, auth method, auth data, connection attributes, error. // The original data is not pointed at, and can be freed. -func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) { +func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, ConnectionAttributesMap, error) { pos := 0 // Client flags, 4 bytes. clientFlags, pos, ok := readUint32(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags") } if clientFlags&CapabilityClientProtocol41 == 0 { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1") } // Remember a subset of the capabilities, so we can use them @@ -666,13 +668,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by // See doc.go for more information. _, pos, ok = readUint32(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize") } // Character set. Need to handle it. characterSet, pos, ok := readByte(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet") } c.CharacterSet = collations.ID(characterSet) @@ -686,13 +688,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by c.conn = conn c.bufferedReader.Reset(conn) c.Capabilities |= CapabilityClientSSL - return "", "", nil, nil + return "", "", nil, nil, nil } // username username, pos, ok := readNullString(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username") } // auth-response can have three forms. @@ -701,29 +703,29 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by var l uint64 l, pos, ok = readLenEncInt(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length") } authResponse, pos, ok = readBytesCopy(data, pos, int(l)) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") } } else if clientFlags&CapabilityClientSecureConnection != 0 { var l byte l, pos, ok = readByte(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length") } authResponse, pos, ok = readBytesCopy(data, pos, int(l)) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") } } else { a := "" a, pos, ok = readNullString(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response") } authResponse = []byte(a) } @@ -733,7 +735,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by dbname := "" dbname, pos, ok = readNullString(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname") } c.schemaName = dbname } @@ -744,7 +746,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by var authMethodStr string authMethodStr, pos, ok = readNullString(data, pos) if !ok { - return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod") + return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod") } // The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password if authMethodStr != "" { @@ -753,16 +755,20 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by } // Decode connection attributes send by the client + var clientAttributes map[string]string if clientFlags&CapabilityClientConnAttr != 0 { - if _, _, err := parseConnAttrs(data, pos); err != nil { + ca, _, err := parseConnAttrs(data, pos) + if err != nil { log.Warningf("Decode connection attributes send by the client: %v", err) } + + clientAttributes = ca } - return username, AuthMethodDescription(authMethod), authResponse, nil + return username, AuthMethodDescription(authMethod), authResponse, clientAttributes, nil } -func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { +func parseConnAttrs(data []byte, pos int) (ConnectionAttributesMap, int, error) { var attrLen uint64 attrLen, pos, ok := readLenEncInt(data, pos) diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index b0466b700e5..b4877942ba1 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -343,7 +343,7 @@ func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession { } var err error - session, err = ph.proxy.NewSession(options) + session, err = ph.proxy.NewSession(options, c.ConnectionAttributes) if err != nil { log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err) } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 1d16f360bd4..d1a336566de 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -21,6 +21,7 @@ package vtgateproxy import ( "context" "flag" + "fmt" "io" "time" @@ -68,11 +69,16 @@ func (proxy *VTGateProxy) connect(ctx context.Context) error { return nil } -func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions) (*vtgateconn.VTGateSession, error) { +func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) { if proxy.conn == nil { return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected") } + target, ok := connectionAttributes["target"] + if ok { + fmt.Printf("Creating new session from upstream provided target string: %v\n", target) + } + // XXX/demmer handle schemaName? return proxy.conn.Session("", options), nil } @@ -95,8 +101,6 @@ func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgateconn.VTGat } func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (qr *sqltypes.Result, err error) { - log.Infof("Execute %s", sql) - if proxy.conn == nil { return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected") }