diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go index c634cbf..a066eed 100644 --- a/integration-tests/connect_test.go +++ b/integration-tests/connect_test.go @@ -47,31 +47,64 @@ func TestGoCqlConnect(t *testing.T) { } func TestProtocolVersionNegotiation(t *testing.T) { - c := setup.NewTestConfig("", "") - c.ControlConnMaxProtocolVersion = 4 // configure unsupported protocol version - testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, &simulacron.ClusterVersion{"2.1", "2.1"}) - require.Nil(t, err) - defer testSetup.Cleanup() + tests := []struct { + name string + clusterVersion string + controlConnMaxProtocolVersion string + negotiatedProtocolVersion primitive.ProtocolVersion + }{ + { + name: "Cluster2.1_MaxCCProtoVer4_NegotiatedProtoVer3", + clusterVersion: "2.1", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion3, // protocol downgraded to V3, V4 is not supported + }, + { + name: "Cluster3.0_MaxCCProtoVer4_NegotiatedProtoVer4", + clusterVersion: "3.0", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion4, + }, + { + name: "Cluster4.0_MaxCCProtoVer4_NegotiatedProtoVer4", + clusterVersion: "4.0", + controlConnMaxProtocolVersion: "4", + negotiatedProtocolVersion: primitive.ProtocolVersion4, + }, + } - // Connect to proxy as a "client" - proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := setup.NewTestConfig("", "") + c.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + testSetup, err := setup.NewSimulacronTestSetupWithSessionAndNodesAndConfig(t, true, false, 1, c, + &simulacron.ClusterVersion{tt.clusterVersion, tt.clusterVersion}) + require.Nil(t, err) + defer testSetup.Cleanup() - if err != nil { - t.Fatal("Unable to connect to proxy session.") - } - defer proxy.Close() + // Connect to proxy as a "client" + proxy, err := utils.ConnectToClusterUsingVersion("127.0.0.1", "", "", 14002, 3) - iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() - result, err := iter.SliceMap() + if err != nil { + t.Fatal("Unable to connect to proxy session.") + } + defer proxy.Close() - if err != nil { - t.Fatal("query failed:", err) - } + cqlConn, _ := testSetup.Proxy.GetOriginControlConn().GetConnAndContactPoint() + negotiatedProto := cqlConn.GetProtocolVersion().Load().(primitive.ProtocolVersion) - require.Equal(t, 0, len(result)) + require.Equal(t, tt.negotiatedProtocolVersion, negotiatedProto) - // simulacron generates fake response metadata when queries aren't primed - require.Equal(t, "fake", iter.Columns()[0].Name) + iter := proxy.Query("SELECT * FROM fakeks.faketb").Iter() + result, err := iter.SliceMap() + + if err != nil { + t.Fatal("query failed:", err) + } + + require.Equal(t, 0, len(result)) + }) + } } func TestMaxClientsThreshold(t *testing.T) { diff --git a/integration-tests/setup/testcluster.go b/integration-tests/setup/testcluster.go index 7791cab..5261321 100644 --- a/integration-tests/setup/testcluster.go +++ b/integration-tests/setup/testcluster.go @@ -452,7 +452,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config { conf.ReadMode = config.ReadModePrimaryOnly conf.SystemQueriesMode = config.SystemQueriesModeOrigin conf.AsyncHandshakeTimeoutMs = 4000 - conf.ControlConnMaxProtocolVersion = 3 + conf.ControlConnMaxProtocolVersion = "3" conf.ProxyRequestTimeoutMs = 10000 diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index c48b5e6..4df2356 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -21,7 +21,7 @@ type Config struct { ReplaceCqlFunctions bool `default:"false" split_words:"true"` AsyncHandshakeTimeoutMs int `default:"4000" split_words:"true"` LogLevel string `default:"INFO" split_words:"true"` - ControlConnMaxProtocolVersion uint `default:"3" split_words:"true"` + ControlConnMaxProtocolVersion string `default:"3" split_words:"true"` // Numeric Cassandra OSS protocol version or Dse1 / Dse2 // Proxy Topology (also known as system.peers "virtualization") bucket @@ -283,6 +283,11 @@ func (c *Config) Validate() error { return err } + _, err = c.ParseControlConnMaxProtocolVersion() + if err != nil { + return err + } + return nil } @@ -337,6 +342,24 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } +func (c *Config) ParseControlConnMaxProtocolVersion() (uint, error) { + switch c.ControlConnMaxProtocolVersion { + case "Dse2": + return 0b_1_000010, nil + case "Dse1": + return 0b_1_000001, nil + } + ver, err := strconv.ParseUint(c.ControlConnMaxProtocolVersion, 10, 32) + if err != nil { + return 0, fmt.Errorf("could not parse control connection max protocol version, valid values are "+ + "2, 3, 4, Dse1, Dse2; original err: %w", err) + } + if ver < 2 || ver > 4 { + return 0, fmt.Errorf("invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2") + } + return uint(ver), nil +} + func (c *Config) ParseLogLevel() (log.Level, error) { level, err := log.ParseLevel(strings.TrimSpace(c.LogLevel)) if err != nil { diff --git a/proxy/pkg/config/config_test.go b/proxy/pkg/config/config_test.go index 5265131..6da7c43 100644 --- a/proxy/pkg/config/config_test.go +++ b/proxy/pkg/config/config_test.go @@ -93,3 +93,87 @@ func TestTargetConfig_WithHostnameButWithoutPort(t *testing.T) { require.Nil(t, err) require.Equal(t, 9042, c.TargetPort) } + +func TestTargetConfig_ParsingControlConnMaxProtocolVersion(t *testing.T) { + defer clearAllEnvVars() + + // general setup + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + + // test-specific setup + setTargetContactPointsAndPortEnvVars() + + conf, _ := New().ParseEnvVars() + + tests := []struct { + name string + controlConnMaxProtocolVersion string + parsedProtocolVersion uint + errorMessage string + }{ + { + name: "ParsedV2", + controlConnMaxProtocolVersion: "2", + parsedProtocolVersion: 2, + errorMessage: "", + }, + { + name: "ParsedV3", + controlConnMaxProtocolVersion: "3", + parsedProtocolVersion: 3, + errorMessage: "", + }, + { + name: "ParsedV4", + controlConnMaxProtocolVersion: "4", + parsedProtocolVersion: 4, + errorMessage: "", + }, + { + name: "ParsedDse1", + controlConnMaxProtocolVersion: "Dse1", + parsedProtocolVersion: 65, + errorMessage: "", + }, + { + name: "ParsedDse2", + controlConnMaxProtocolVersion: "Dse2", + parsedProtocolVersion: 66, + errorMessage: "", + }, + { + name: "UnsupportedCassandraV5", + controlConnMaxProtocolVersion: "5", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + { + name: "UnsupportedCassandraV1", + controlConnMaxProtocolVersion: "1", + parsedProtocolVersion: 0, + errorMessage: "invalid control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + { + name: "InvalidValue", + controlConnMaxProtocolVersion: "Dsev123", + parsedProtocolVersion: 0, + errorMessage: "could not parse control connection max protocol version, valid values are 2, 3, 4, Dse1, Dse2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf.ControlConnMaxProtocolVersion = tt.controlConnMaxProtocolVersion + ver, err := conf.ParseControlConnMaxProtocolVersion() + if ver == 0 { + require.NotNil(t, err) + require.Contains(t, err.Error(), tt.errorMessage) + } else { + require.Equal(t, tt.parsedProtocolVersion, ver) + } + }) + } +} diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 7fea4d5..b247243 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -125,7 +125,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { log.Infof("Received topology event from %v, refreshing topology.", cc.connConfig.GetClusterType()) - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { log.Debugf("Topology refresh scheduled but the control connection isn't open. " + "Falling back to the connection where the event was received.") @@ -162,7 +162,7 @@ func (cc *ControlConn) Start(wg *sync.WaitGroup, ctx context.Context) error { cc.Close() } - conn, _ := cc.getConnAndContactPoint() + conn, _ := cc.GetConnAndContactPoint() if conn == nil { useContactPointsOnly := false if !lastOpenSuccessful { @@ -251,7 +251,7 @@ func (cc *ControlConn) ReadFailureCounter() int { } func (cc *ControlConn) Open(contactPointsOnly bool, ctx context.Context) (CqlConnection, error) { - oldConn, _ := cc.getConnAndContactPoint() + oldConn, _ := cc.GetConnAndContactPoint() if oldConn != nil { cc.Close() oldConn = nil @@ -321,7 +321,8 @@ func (cc *ControlConn) openInternal(endpoints []Endpoint, ctx context.Context) ( currentIndex := (firstEndpointIndex + i) % len(endpoints) endpoint = endpoints[currentIndex] - newConn, err := cc.connAndNegotiateProtoVer(endpoint, cc.conf.ControlConnMaxProtocolVersion, ctx) + maxProtoVer, _ := cc.conf.ParseControlConnMaxProtocolVersion() + newConn, err := cc.connAndNegotiateProtoVer(endpoint, maxProtoVer, ctx) if err == nil { newConn.SetEventHandler(func(f *frame.Frame, c CqlConnection) { @@ -678,7 +679,7 @@ func (cc *ControlConn) setConn(oldConn CqlConnection, newConn CqlConnection, new return cc.cqlConn, cc.currentContactPoint } -func (cc *ControlConn) getConnAndContactPoint() (CqlConnection, Endpoint) { +func (cc *ControlConn) GetConnAndContactPoint() (CqlConnection, Endpoint) { cc.cqlConnLock.Lock() conn := cc.cqlConn contactPoint := cc.currentContactPoint diff --git a/proxy/pkg/zdmproxy/cqlconn.go b/proxy/pkg/zdmproxy/cqlconn.go index 0014321..8749516 100644 --- a/proxy/pkg/zdmproxy/cqlconn.go +++ b/proxy/pkg/zdmproxy/cqlconn.go @@ -38,6 +38,7 @@ type CqlConnection interface { SetEventHandler(eventHandler func(f *frame.Frame, conn CqlConnection)) SubscribeToProtocolEvents(ctx context.Context, eventTypes []primitive.EventType) error IsAuthEnabled() (bool, error) + GetProtocolVersion() *atomic.Value } // Not thread safe @@ -98,6 +99,7 @@ func NewCqlConnection( eventHandlerLock: &sync.Mutex{}, authEnabled: true, frameProcessor: NewStreamIdProcessor(NewInternalStreamIdMapper(conf.ProxyMaxStreamIds, nil)), + protocolVersion: &atomic.Value{}, } cqlConn.StartRequestLoop() cqlConn.StartResponseLoop() @@ -233,13 +235,16 @@ func (c *cqlConn) IsAuthEnabled() (bool, error) { return c.authEnabled, nil } +func (c *cqlConn) GetProtocolVersion() *atomic.Value { + return c.protocolVersion +} + func (c *cqlConn) InitializeContext(version primitive.ProtocolVersion, ctx context.Context) error { authEnabled, err := c.PerformHandshake(version, ctx) if err != nil { return fmt.Errorf("failed to perform handshake: %w", err) } - c.protocolVersion = &atomic.Value{} c.protocolVersion.Store(version) c.initialized = true c.authEnabled = authEnabled