From 9257cbd46e745df8459fb815d780890e7d7d69e0 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Mon, 17 Jun 2024 11:20:29 +0200 Subject: [PATCH] ZDM-71: Introduce protocol negotiation --- integration-tests/connect_test.go | 35 +++++++++++++ integration-tests/setup/testcluster.go | 1 + integration-tests/utils/testutils.go | 9 +++- proxy/pkg/config/config.go | 1 + proxy/pkg/zdmproxy/controlconn.go | 70 ++++++++++++++++++++------ proxy/pkg/zdmproxy/cqlconn.go | 8 +-- 6 files changed, 105 insertions(+), 19 deletions(-) diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index b7df5f71..77fb11bf 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -8,6 +8,7 @@ import ( "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/datastax/zdm-proxy/integration-tests/client" + "github.com/datastax/zdm-proxy/integration-tests/env" "github.com/datastax/zdm-proxy/integration-tests/setup" "github.com/datastax/zdm-proxy/integration-tests/utils" "github.com/datastax/zdm-proxy/proxy/pkg/config" @@ -45,6 +46,40 @@ func TestGoCqlConnect(t *testing.T) { require.Equal(t, "fake", iter.Columns()[0].Name) } +func TestProtocolVersionNegotiation(t *testing.T) { + testCassandraVersion := env.CassandraVersion + env.CassandraVersion = "2.1" // downgrade C* version for protocol negotiation test + defer func() { + env.CassandraVersion = testCassandraVersion + }() + c := setup.NewTestConfig("", "") + c.ProtocolVersion = 4 // configure unsupported protocol version + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.Nil(t, err) + defer testSetup.Cleanup() + + // Connect to proxy as a "client" + proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) + + if err != nil { + t.Log("Unable to connect to proxy session.") + t.Fatal(err) + } + defer proxy.Close() + + iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() + result, err := iter.SliceMap() + + if err != nil { + t.Fatal("query failed:", err) + } + + require.Equal(t, 0, len(result)) + + // simulacron generates fake response metadata when queries aren't primed + require.Equal(t, "fake", iter.Columns()[0].Name) +} + func TestMaxClientsThreshold(t *testing.T) { maxClients := 10 goCqlConnectionsPerHost := 1 diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 1eb60144..dac16ac0 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,6 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 + conf.ProtocolVersion = 3 conf.ProxyRequestTimeoutMs = 10000 diff --git a/integration-tests/utils/testutils.go b/integration-tests/utils/testutils.go index 2c050ecd..e0ca5edd 100644 --- a/integration-tests/utils/testutils.go +++ b/integration-tests/utils/testutils.go @@ -116,9 +116,9 @@ func CheckMetricsEndpointResult(httpAddr string, success bool) error { return nil } -// ConnectToCluster is used to connect to source and destination clusters -func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { +func ConnectToClusterUsingVersion(hostname string, username string, password string, port int, protoVersion int) (*gocql.Session, error) { cluster := NewCluster(hostname, username, password, port) + cluster.ProtoVersion = protoVersion session, err := cluster.CreateSession() log.Debugf("Connection established with Cluster: %s:%d", cluster.Hosts[0], cluster.Port) if err != nil { @@ -127,6 +127,11 @@ func ConnectToCluster(hostname string, username string, password string, port in return session, nil } +// ConnectToCluster is used to connect to source and destination clusters +func ConnectToCluster(hostname string, username string, password string, port int) (*gocql.Session, error) { + return ConnectToClusterUsingVersion(hostname, username, password, port, 4) +} + // NewCluster initializes a ClusterConfig object with common settings func NewCluster(hostname string, username string, password string, port int) *gocql.ClusterConfig { cluster := gocql.NewCluster(hostname) diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index d5cc5c67..6e6c4027 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -21,6 +21,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` LogLevel string `default:"INFO" split_words:"true"` + ProtocolVersion uint `default:"3" split_words:"true"` // Proxy Topology (also known as system.peers "virtualization") bucket diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index e32bc967..e99f683f 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -54,11 +54,11 @@ type ControlConn struct { protocolEventSubscribers map[ProtocolEventObserver]interface{} authEnabled *atomic.Value metricsHandler *metrics.MetricHandler + protocolVersion primitive.ProtocolVersion } const ProxyVirtualRack = "rack0" const ProxyVirtualPartitioner = "org.apache.cassandra.dht.Murmur3Partitioner" -const ccProtocolVersion = primitive.ProtocolVersion3 const ccWriteTimeout = 5 * time.Second const ccReadTimeout = 10 * time.Second @@ -320,15 +320,9 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) - if err != nil { - log.Warnf("Failed to open control connection to %v using endpoint %v: %v", - cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) - continue - } - newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) - err = newConn.InitializeContext(ccProtocolVersion, ctx) + newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ProtocolVersion, ctx) + if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { switch f.Body.Message.(type) { @@ -355,9 +349,11 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( log.Warnf("Error while initializing a new cql connection for the control connection of %v: %v", cc.connConfig.GetClusterType(), err) } - err2 := newConn.Close() - if err2 != nil { - log.Errorf("Failed to close cql connection: %v", err2) + if newConn != nil { + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } } continue @@ -372,6 +368,52 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( return conn, endpoint } +func (cc *ControlConn) connAndNegotiateProtoVer(endpoint Endpoint, initialProtoVer uint, ctx context.Context) (CqlConnection, error) { + protoVer := primitive.ProtocolVersion(initialProtoVer) + for { + tcpConn, _, err := openConnection(cc.connConfig, endpoint, ctx, false) + if err != nil { + log.Warnf("Failed to open control connection to %v using endpoint %v: %v", + cc.connConfig.GetClusterType(), endpoint.GetEndpointIdentifier(), err) + return nil, err + } + newConn := NewCqlConnection(tcpConn, cc.username, cc.password, ccReadTimeout, ccWriteTimeout, cc.conf) + err = newConn.InitializeContext(protoVer, ctx) + if err != nil && strings.Contains(err.Error(), "Invalid or unsupported protocol version") { + // unsupported protocol version + // protocol renegotiation requires opening a new TCP connection + err2 := newConn.Close() + if err2 != nil { + log.Errorf("Failed to close cql connection: %v", err2) + } + protoVer = downgradeProtocol(protoVer) + log.Infof("Downgrading protocol version: %v", protoVer) + if protoVer == 0 { + // we cannot downgrade anymore + return nil, err + } + continue // retry lower protocol version + } else { + cc.protocolVersion = protoVer + return newConn, err // we may have successfully established connection or faced other error + } + } +} + +func downgradeProtocol(version primitive.ProtocolVersion) primitive.ProtocolVersion { + switch version { + case primitive.ProtocolVersionDse2: + return primitive.ProtocolVersionDse1 + case primitive.ProtocolVersionDse1: + return primitive.ProtocolVersion4 + case primitive.ProtocolVersion4: + return primitive.ProtocolVersion3 + case primitive.ProtocolVersion3: + return primitive.ProtocolVersion2 + } + return 0 +} + func (cc *ControlConn) Close() { cc.cqlConnLock.Lock() conn := cc.cqlConn @@ -387,7 +429,7 @@ func (cc *ControlConn) Close() { } func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([]*Host, error) { - localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + localQueryResult, err := conn.Query("SELECT * FROM system.local", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.local table: %w", err) } @@ -410,7 +452,7 @@ func (cc *ControlConn) RefreshHosts(conn CqlConnection, ctx context.Context) ([] } } - peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), ccProtocolVersion, ctx) + peersQuery, err := conn.Query("SELECT * FROM system.peers", GetDefaultGenericTypeCodec(), cc.protocolVersion, ctx) if err != nil { return nil, fmt.Errorf("could not fetch information from system.peers table: %w", err) } diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 041894fe..c8a6e43d 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -59,6 +59,7 @@ type cqlConn struct { eventHandlerLock *sync.Mutex authEnabled bool frameProcessor FrameProcessor + protocolVersion primitive.ProtocolVersion } var ( @@ -237,6 +238,7 @@ func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx conte return fmt.Errorf("failed to perform handshake: %w", err) } + c.protocolVersion = version c.initialized = true c.authEnabled = authEnabled return nil @@ -375,7 +377,7 @@ func (c *cqlConn) Query( }, } - queryFrame := frame.NewFrame(ccProtocolVersion, -1, queryMsg) + queryFrame := frame.NewFrame(c.protocolVersion, -1, queryMsg) var rowSet *ParsedRowSet for { localResponse, err := c.SendAndReceive(queryFrame, ctx) @@ -429,7 +431,7 @@ func (c *cqlConn) Query( } func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Message, error) { - queryFrame := frame.NewFrame(ccProtocolVersion, -1, msg) + queryFrame := frame.NewFrame(c.protocolVersion, -1, msg) localResponse, err := c.SendAndReceive(queryFrame, ctx) if err != nil { return nil, err @@ -440,7 +442,7 @@ func (c *cqlConn) Execute(msg message.Message, ctx context.Context) (message.Mes func (c *cqlConn) SendHeartbeat(ctx context.Context) error { optionsMsg := &message.Options{} - heartBeatFrame := frame.NewFrame(ccProtocolVersion, -1, optionsMsg) + heartBeatFrame := frame.NewFrame(c.protocolVersion, -1, optionsMsg) response, err := c.SendAndReceive(heartBeatFrame, ctx) if err != nil {