
Stream ID verification
lukasz-antoniak committed Jul 19, 2024
1 parent 21eebb6 commit 5ea7eef
Showing 9 changed files with 181 additions and 74 deletions.
28 changes: 18 additions & 10 deletions integration-tests/connect_test.go
@@ -159,20 +159,23 @@ func TestMaxClientsThreshold(t *testing.T) {

func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) {
tests := []struct {
- name string
- requestVersion primitive.ProtocolVersion
- expectedVersion primitive.ProtocolVersion
- errExpected string
+ name string
+ requestVersion primitive.ProtocolVersion
+ negotiatedVersion string
+ expectedVersion primitive.ProtocolVersion
+ errExpected string
}{
{
"request v5, response v4",
primitive.ProtocolVersion5,
"4",
primitive.ProtocolVersion4,
"Invalid or unsupported protocol version (5)",
},
{
"request v1, response v4",
primitive.ProtocolVersion(0x1),
"4",
primitive.ProtocolVersion4,
"Invalid or unsupported protocol version (1)",
},
@@ -189,6 +192,7 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) {
defer zerolog.SetGlobalLevel(oldZeroLogLevel)

cfg := setup.NewTestConfig("127.0.1.1", "127.0.1.2")
+ cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion
cfg.LogLevel = "TRACE" // saw 1 test failure here once but logs didn't show enough info
testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false)
require.Nil(t, err)
@@ -218,30 +222,34 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) {

func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) {
type test struct {
- name string
- requestVersion primitive.ProtocolVersion
- returnedVersion primitive.ProtocolVersion
- expectedVersion primitive.ProtocolVersion
- errExpected string
+ name string
+ requestVersion primitive.ProtocolVersion
+ negotiatedVersion string
+ returnedVersion primitive.ProtocolVersion
+ expectedVersion primitive.ProtocolVersion
+ errExpected string
}
tests := []*test{
{
"DSE_V2 request, v5 returned, v4 expected",
primitive.ProtocolVersionDse2,
"4",
primitive.ProtocolVersion5,
primitive.ProtocolVersion4,
"Invalid or unsupported protocol version (5)",
},
{
"DSE_V2 request, v1 returned, v4 expected",
primitive.ProtocolVersionDse2,
"4",
primitive.ProtocolVersion(0x01),
primitive.ProtocolVersion4,
"Invalid or unsupported protocol version (1)",
},
}

runTestFunc := func(t *testing.T, test *test, cfg *config.Config) {
+ cfg.ControlConnMaxProtocolVersion = test.negotiatedVersion // simulate what version was negotiated on control connection
testSetup, err := setup.NewCqlServerTestSetup(t, cfg, false, false, false)
require.Nil(t, err)
defer testSetup.Cleanup()
@@ -299,7 +307,7 @@ func TestReturnedProtocolVersionUnsupportedByProxy(t *testing.T) {
}

func createFrameWithUnsupportedVersion(version primitive.ProtocolVersion, streamId int16, isResponse bool) ([]byte, error) {
- mostSimilarVersion := primitive.ProtocolVersion4
+ mostSimilarVersion := version
if version > primitive.ProtocolVersionDse2 {
mostSimilarVersion = primitive.ProtocolVersionDse2
} else if version < primitive.ProtocolVersion2 {
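Note: the hunk above cuts off before the lower-bound branch of the clamp. A minimal self-contained sketch of the version-clamping logic after this change, assuming versions below v2 are raised to v2 (the helper name closestEncodableVersion is illustrative, not from the commit):

	package main

	import (
		"fmt"

		"github.com/datastax/go-cassandra-native-protocol/primitive"
	)

	// closestEncodableVersion mirrors the clamp at the top of
	// createFrameWithUnsupportedVersion: after this commit the frame is encoded
	// with the requested version itself, clamped into the encodable range
	// [v2, DSE v2]. The lower-bound branch is assumed, since the hunk is truncated.
	func closestEncodableVersion(version primitive.ProtocolVersion) primitive.ProtocolVersion {
		mostSimilarVersion := version
		if version > primitive.ProtocolVersionDse2 {
			mostSimilarVersion = primitive.ProtocolVersionDse2
		} else if version < primitive.ProtocolVersion2 {
			mostSimilarVersion = primitive.ProtocolVersion2
		}
		return mostSimilarVersion
	}

	func main() {
		fmt.Println(closestEncodableVersion(primitive.ProtocolVersion(0x1))) // clamps up to v2
		fmt.Println(closestEncodableVersion(primitive.ProtocolVersion5))     // already encodable, kept as-is
	}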
47 changes: 47 additions & 0 deletions integration-tests/streamids_test.go
@@ -119,6 +119,53 @@ func TestLimitStreamIdsGeneration(t *testing.T) {
}
}

+ func TestFailOnNegativeStreamIDsFromClient(t *testing.T) {
+ originAddress := "127.0.1.1"
+ targetAddress := "127.0.1.2"
+ originProtoVer := primitive.ProtocolVersion2
+ targetProtoVer := primitive.ProtocolVersion2
+ serverConf := setup.NewTestConfig(originAddress, targetAddress)
+ proxyConf := setup.NewTestConfig(originAddress, targetAddress)
+
+ queryInsert := &message.Query{
+ Query: "INSERT INTO test_ks.test(key, value) VALUES(1, '1')", // use INSERT to route request to both clusters
+ }
+
+ testSetup, err := setup.NewCqlServerTestSetup(t, serverConf, false, false, false)
+ require.Nil(t, err)
+ defer testSetup.Cleanup()
+
+ originRequestHandler := NewMaxStreamIdsRequestHandler("origin", "dc1", originAddress, 100)
+ targetRequestHandler := NewProtocolNegotiationRequestHandler("target", "dc1", targetAddress, []primitive.ProtocolVersion{targetProtoVer})
+
+ testSetup.Origin.CqlServer.RequestHandlers = []client.RequestHandler{
+ originRequestHandler.HandleRequest,
+ client.NewDriverConnectionInitializationHandler("origin", "dc1", func(_ string) {}),
+ }
+ testSetup.Target.CqlServer.RequestHandlers = []client.RequestHandler{
+ targetRequestHandler.HandleRequest,
+ client.NewDriverConnectionInitializationHandler("target", "dc1", func(_ string) {}),
+ }
+
+ err = testSetup.Start(nil, false, originProtoVer)
+ require.Nil(t, err)
+
+ proxyConf.ProxyMaxStreamIds = 100
+ proxy, err := setup.NewProxyInstanceWithConfig(proxyConf) // starts the proxy
+ if proxy != nil {
+ defer proxy.Shutdown()
+ }
+ require.Nil(t, err)
+
+ cqlConn, err := testSetup.Client.CqlClient.ConnectAndInit(context.Background(), originProtoVer, 0)
+ require.Nil(t, err)
+ defer cqlConn.Close()
+
+ response, _ := cqlConn.SendAndReceive(frame.NewFrame(originProtoVer, -1, queryInsert))
+ require.IsType(t, response.Body.Message, &message.ProtocolError{})
+ require.Equal(t, "negative stream id: -1", response.Body.Message.(*message.ProtocolError).ErrorMessage)
+ }
+
type MaxStreamIdsRequestHandler struct {
lock sync.Mutex
cluster string
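Note: the stream-id mapper that produces this error lives in proxy/pkg/zdmproxy/streamids.go, one of the changed files not rendered on this page. A minimal sketch of the validation the test above exercises, assuming the mapper rejects negative client stream ids before allocating a proxy-side id; only the two error strings are taken from this commit (see handleRequestSendFailure further down), everything else is illustrative:

	package main

	import "fmt"

	// streamIdMapper is a hypothetical stand-in for the mapper created by
	// NewStreamIdMapper(protoVer, config, streamIdsMetric) in clienthandler.go.
	type streamIdMapper struct {
		free []int16 // pool of available proxy-side stream ids
	}

	// getNewIdFor rejects negative client stream ids; the error text must
	// contain "negative stream id" so handleRequestSendFailure can translate
	// it into a ProtocolError response, while "no stream id available" maps
	// to an Overloaded response instead.
	func (m *streamIdMapper) getNewIdFor(clientStreamId int16) (int16, error) {
		if clientStreamId < 0 {
			return -1, fmt.Errorf("negative stream id: %v", clientStreamId)
		}
		if len(m.free) == 0 {
			return -1, fmt.Errorf("no stream id available")
		}
		id := m.free[0]
		m.free = m.free[1:]
		return id, nil
	}

	func main() {
		m := &streamIdMapper{free: []int16{0, 1, 2}}
		_, err := m.getNewIdFor(-1)
		fmt.Println(err) // prints: negative stream id: -1 (what the test asserts)
	}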
18 changes: 10 additions & 8 deletions proxy/pkg/zdmproxy/clientconn.go
@@ -56,6 +56,8 @@ type ClientConnector struct {
readScheduler *Scheduler

shutdownRequestCtx context.Context
+
+ minProtoVer primitive.ProtocolVersion
}

func NewClientConnector(
@@ -71,7 +73,8 @@
readScheduler *Scheduler,
writeScheduler *Scheduler,
shutdownRequestCtx context.Context,
- clientHandlerShutdownRequestCancelFn context.CancelFunc) *ClientConnector {
+ clientHandlerShutdownRequestCancelFn context.CancelFunc,
+ minProtoVer primitive.ProtocolVersion) *ClientConnector {

return &ClientConnector{
connection: connection,
@@ -97,6 +100,7 @@
readScheduler: readScheduler,
shutdownRequestCtx: shutdownRequestCtx,
clientHandlerShutdownRequestCancelFn: clientHandlerShutdownRequestCancelFn,
+ minProtoVer: minProtoVer,
}
}

@@ -176,7 +180,7 @@ func (cc *ClientConnector) listenForRequests() {
for cc.clientHandlerContext.Err() == nil {
f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext)

- protocolErrResponseFrame, err, _ := checkProtocolError(f, err, protocolErrOccurred, ClientConnectorLogPrefix)
+ protocolErrResponseFrame, err, _ := checkProtocolError(f, cc.minProtoVer, err, protocolErrOccurred, ClientConnectorLogPrefix)
if err != nil {
handleConnectionError(
err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr)
@@ -224,7 +228,7 @@ func (cc *ClientConnector) sendOverloadedToClient(request *frame.RawFrame) {
}
}

- func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) {
+ func checkProtocolError(f *frame.RawFrame, protoVer primitive.ProtocolVersion, connErr error, protocolErrorOccurred bool, prefix string) (protocolErrResponse *frame.RawFrame, fatalErr error, errorCode int8) {
var protocolErrMsg *message.ProtocolError
var streamId int16
var logMsg string
@@ -244,7 +248,7 @@ func checkProtocolError(f *frame.RawFrame, connErr error, protocolErrorOccurred
if !protocolErrorOccurred {
log.Debugf("[%v] %v Returning a protocol error to the client to force a downgrade: %v.", prefix, logMsg, protocolErrMsg)
}
- rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protocolErrMsg)
+ rawProtocolErrResponse, err := generateProtocolErrorResponseFrame(streamId, protoVer, protocolErrMsg)
if err != nil {
return nil, fmt.Errorf("could not generate protocol error response raw frame (%v): %v", protocolErrMsg, err), -1
} else {
@@ -255,10 +259,8 @@
}
}

- func generateProtocolErrorResponseFrame(streamId int16, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) {
- // ideally we would use the maximum version between the versions used by both control connections if
- // control connections implemented protocol version negotiation
- response := frame.NewFrame(primitive.ProtocolVersion4, streamId, protocolErrMsg)
+ func generateProtocolErrorResponseFrame(streamId int16, protoVer primitive.ProtocolVersion, protocolErrMsg *message.ProtocolError) (*frame.RawFrame, error) {
+ response := frame.NewFrame(protoVer, streamId, protocolErrMsg)
rawResponse, err := defaultCodec.ConvertToRawFrame(response)
if err != nil {
return nil, err
75 changes: 42 additions & 33 deletions proxy/pkg/zdmproxy/clienthandler.go
@@ -16,6 +16,7 @@ import (
log "github.com/sirupsen/logrus"
"net"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
@@ -170,10 +171,16 @@ func NewClientHandler(
// Initialize stream id processors to manage the ids sent to the clusters
originCCProtoVer := originControlConn.cqlConn.GetProtocolVersion()
targetCCProtoVer := targetControlConn.cqlConn.GetProtocolVersion()
- streamIds := maxStreamIds(originCCProtoVer, targetCCProtoVer, conf)
- originFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeOrigin)
- targetFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeTarget)
- asyncFrameProcessor := newFrameProcessor(streamIds, nodeMetrics, ClusterConnectorTypeAsync)
+ // Calculate the maximum number of stream IDs. Take the oldest protocol version negotiated between the two clusters
+ // and apply the limit defined in the proxy configuration. If the origin or target cluster is still running protocol V2,
+ // we limit the maximum number of stream IDs to 127 on both clusters. The logic is based on Java driver version 3.x.
+ // Java driver 3.x was the last one supporting protocol V2. It establishes the control connection first, and then
+ // uses the negotiated protocol version to configure the maximum number of stream IDs on node connections. The driver does
+ // NOT change the number of stream IDs on a per-node basis. The maximum stream ID is calculated while creating the stream ID mapper.
+ minimalProtoVer := minProtoVer(originCCProtoVer, targetCCProtoVer)
+ originFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeOrigin)
+ targetFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeTarget)
+ asyncFrameProcessor := newFrameProcessor(minimalProtoVer, conf, nodeMetrics, ClusterConnectorTypeAsync)

closeFrameProcessors := func() {
originFrameProcessor.Close()
@@ -200,7 +207,7 @@
originConnector, err := NewClusterConnector(
originCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg,
clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx,
- false, nil, handshakeDone, originFrameProcessor)
+ false, nil, handshakeDone, originFrameProcessor, originCCProtoVer)
if err != nil {
clientHandlerCancelFunc()
return nil, err
@@ -209,7 +216,7 @@
targetConnector, err := NewClusterConnector(
targetCassandraConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg,
clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx,
- false, nil, handshakeDone, targetFrameProcessor)
+ false, nil, handshakeDone, targetFrameProcessor, targetCCProtoVer)
if err != nil {
clientHandlerCancelFunc()
return nil, err
@@ -227,7 +234,7 @@
asyncConnector, err = NewClusterConnector(
asyncConnInfo, conf, psCache, nodeMetrics, localClientHandlerWg, clientHandlerRequestWg,
clientHandlerContext, clientHandlerCancelFunc, respChannel, readScheduler, writeScheduler, requestsDoneCtx,
- true, asyncPendingRequests, handshakeDone, asyncFrameProcessor)
+ true, asyncPendingRequests, handshakeDone, asyncFrameProcessor, originCCProtoVer)
if err != nil {
log.Errorf("Could not create async cluster connector to %s, async requests will not be forwarded: %s", asyncConnInfo.connConfig.GetClusterType(), err.Error())
asyncConnector = nil
@@ -263,7 +270,8 @@
readScheduler,
writeScheduler,
clientHandlerShutdownRequestContext,
- clientHandlerShutdownRequestCancelFn),
+ clientHandlerShutdownRequestCancelFn,
+ minProtoVer(originCCProtoVer, targetCCProtoVer)),

asyncConnector: asyncConnector,
originCassandraConnector: originConnector,
@@ -1487,7 +1495,7 @@ func (ch *ClientHandler) executeRequest(
f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin, common.ClusterTypeTarget)
sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest)
if sendErr != nil {
- ch.clientConnector.sendOverloadedToClient(frameContext.frame)
+ ch.handleRequestSendFailure(sendErr, frameContext)
} else {
ch.targetCassandraConnector.sendRequestToCluster(targetRequest)
}
@@ -1496,15 +1504,15 @@
f.Header.OpCode, f.Header.StreamId, common.ClusterTypeOrigin)
sendErr := ch.originCassandraConnector.sendRequestToCluster(originRequest)
if sendErr != nil {
- ch.clientConnector.sendOverloadedToClient(frameContext.frame)
+ ch.handleRequestSendFailure(sendErr, frameContext)
}
ch.targetCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs)
case forwardToTarget:
log.Tracef("Forwarding request with opcode %v for stream %v to %v",
f.Header.OpCode, f.Header.StreamId, common.ClusterTypeTarget)
sendErr := ch.targetCassandraConnector.sendRequestToCluster(targetRequest)
if sendErr != nil {
- ch.clientConnector.sendOverloadedToClient(frameContext.frame)
+ ch.handleRequestSendFailure(sendErr, frameContext)
}
ch.originCassandraConnector.sendHeartbeat(startupFrameVersion, ch.conf.HeartbeatIntervalMs)
case forwardToAsyncOnly:
@@ -1526,6 +1534,20 @@
overallRequestStartTime, requestTimeout)
}

+ func (ch *ClientHandler) handleRequestSendFailure(err error, frameContext *frameDecodeContext) {
+ if strings.Contains(err.Error(), "no stream id available") {
+ ch.clientConnector.sendOverloadedToClient(frameContext.frame)
+ } else if strings.Contains(err.Error(), "negative stream id") {
+ responseMessage := &message.ProtocolError{ErrorMessage: err.Error()}
+ responseFrame, err := generateProtocolErrorResponseFrame(
+ frameContext.frame.Header.StreamId, frameContext.frame.Header.Version, responseMessage)
+ if err != nil {
+ log.Errorf("could not generate protocol error response raw frame (%v): %v", responseMessage, err)
+ }
+ ch.clientConnector.sendResponseToClient(responseFrame)
+ }
+ }
+
func (ch *ClientHandler) handleInterceptedRequest(
requestInfo RequestInfo, frameContext *frameDecodeContext, currentKeyspace string) (*frame.RawFrame, error) {

@@ -2243,7 +2265,8 @@ func GetNodeMetricsByClusterConnector(nodeMetrics *metrics.NodeMetrics, connecto
}
}

- func newFrameProcessor(maxStreamIds int, nodeMetrics *metrics.NodeMetrics, connectorType ClusterConnectorType) FrameProcessor {
+ func newFrameProcessor(protoVer primitive.ProtocolVersion, config *config.Config, nodeMetrics *metrics.NodeMetrics,
+ connectorType ClusterConnectorType) FrameProcessor {
var streamIdsMetric metrics.Gauge
connectorMetrics, err := GetNodeMetricsByClusterConnector(nodeMetrics, connectorType)
if err != nil {
@@ -2254,30 +2277,16 @@
}
var mapper StreamIdMapper
if connectorType == ClusterConnectorTypeAsync {
- mapper = NewInternalStreamIdMapper(maxStreamIds, streamIdsMetric)
+ mapper = NewInternalStreamIdMapper(protoVer, config, streamIdsMetric)
} else {
- mapper = NewStreamIdMapper(maxStreamIds, streamIdsMetric)
+ mapper = NewStreamIdMapper(protoVer, config, streamIdsMetric)
}
return NewStreamIdProcessor(mapper)
}

- // Calculate maximum number of stream IDs. Take the oldest protocol version negotiated between two clusters
- // and apply limit defined in proxy configuration. If origin or target cluster are still running protocol V2,
- // we will limit maximum number of stream IDs to 128 on both clusters. Logic is based on Java driver version 3.x.
- // Java driver 3.x was the last one supporting protocol V2. It establishes control connection first, and then
- // uses negotiated protocol version to configure maximum number of stream IDs on node connections. Driver does NOT
- // change the number of stream IDs on per node basis.
- func maxStreamIds(originProtoVer primitive.ProtocolVersion, targetProtoVer primitive.ProtocolVersion, conf *config.Config) int {
- maxSupported := maxStreamIdsV3
- protoVer := originProtoVer
- if targetProtoVer < originProtoVer {
- protoVer = targetProtoVer
- }
- if protoVer == primitive.ProtocolVersion2 {
- maxSupported = maxStreamIdsV2
- }
- if maxSupported < conf.ProxyMaxStreamIds {
- return maxSupported
- }
- return conf.ProxyMaxStreamIds
+ func minProtoVer(version1 primitive.ProtocolVersion, version2 primitive.ProtocolVersion) primitive.ProtocolVersion {
+ if version1 < version2 {
+ return version1
+ }
+ return version2
}
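Note: maxStreamIds is no longer defined in clienthandler.go; the updated test below calls it as maxStreamIds(protoVer, conf), so it presumably moved next to the stream-id mapper. A sketch of the refactored helper under that assumption, with constant values and import paths assumed rather than taken from this commit:

	package zdmproxy

	import (
		// import paths assumed from the identifiers used in this diff
		"github.com/datastax/go-cassandra-native-protocol/primitive"
		"github.com/datastax/zdm-proxy/proxy/pkg/config"
	)

	const (
		maxStreamIdsV2 = 127  // assumed: v2 stream ids fit in one signed byte
		maxStreamIdsV3 = 2048 // assumed: package default for v3+ (2-byte ids)
	)

	// maxStreamIds caps the per-connection stream-id pool at the smaller of
	// the protocol-version limit and the configured ProxyMaxStreamIds.
	func maxStreamIds(protoVer primitive.ProtocolVersion, conf *config.Config) int {
		maxSupported := maxStreamIdsV3
		if protoVer == primitive.ProtocolVersion2 {
			maxSupported = maxStreamIdsV2
		}
		if maxSupported < conf.ProxyMaxStreamIds {
			return maxSupported
		}
		return conf.ProxyMaxStreamIds
	}

Under this sketch, a v2 origin or target caps both cluster connections at 127 in-flight stream ids regardless of ProxyMaxStreamIds, mirroring the Java driver 3.x behaviour described in the comment earlier in this diff.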
2 changes: 1 addition & 1 deletion proxy/pkg/zdmproxy/clienthandler_test.go
@@ -27,7 +27,7 @@ func TestMaxStreamIds(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ids := maxStreamIds(tt.args.originProtoVer, tt.args.targetProtoVer, tt.args.config)
+ ids := maxStreamIds(minProtoVer(tt.args.originProtoVer, tt.args.targetProtoVer), tt.args.config)
require.Equal(t, tt.expectedMaxStreamIds, ids)
})
}
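Note: the test table itself sits above this hunk and is unchanged by the commit. A hypothetical entry, with field names taken from the call site above and illustrative values only:

	{
		name: "both clusters negotiated v2",
		args: args{
			originProtoVer: primitive.ProtocolVersion2,
			targetProtoVer: primitive.ProtocolVersion2,
			config:         &config.Config{ProxyMaxStreamIds: 2048},
		},
		expectedMaxStreamIds: 127, // v2 caps the pool regardless of config
	},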