diff --git a/go.mod b/go.mod index bf3f753..381a375 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alecthomas/kong v0.2.17 github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec github.com/datastax/go-cassandra-native-protocol v0.0.0-20210604174339-4311e5d5654d + github.com/hashicorp/golang-lru v0.5.4 github.com/stretchr/testify v1.7.0 go.uber.org/atomic v1.8.0 // indirect go.uber.org/multierr v1.7.0 // indirect diff --git a/go.sum b/go.sum index c57145f..58e75fe 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,5 @@ github.com/alecthomas/kong v0.2.17 h1:URDISCI96MIgcIlQyoCAlhOmrSw6pZScBNkctg8r0W0= github.com/alecthomas/kong v0.2.17/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ= - github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -11,8 +10,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E= - github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -21,14 +21,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= - - github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= - go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.8.0 h1:CUhrE4N1rqSE6FM9ecihEjRkLQu8cDfgDyoOs83mEY4= go.uber.org/atomic v1.8.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -38,23 +35,16 @@ go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95a go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= - golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= - golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= - golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= - golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= - gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= \ No newline at end of file +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/proxy/codecs.go b/proxy/codecs.go index 7ee7073..8679b97 100644 --- a/proxy/codecs.go +++ b/proxy/codecs.go @@ -18,10 +18,11 @@ import ( "encoding/hex" "errors" "fmt" + "io" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" - "io" ) var codec = frame.NewRawCodec(&partialQueryCodec{}, &partialExecuteCodec{}) @@ -110,4 +111,4 @@ func (c *partialExecuteCodec) Decode(source io.Reader, _ primitive.ProtocolVersi func (c *partialExecuteCodec) GetOpCode() primitive.OpCode { return primitive.OpCodeExecute -} +} \ No newline at end of file diff --git a/proxy/proxy.go b/proxy/proxy.go index f6d6576..79476ca 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -33,6 +33,7 @@ import ( "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" + lru "github.com/hashicorp/golang-lru" "go.uber.org/zap" ) @@ -50,6 +51,9 @@ type Config struct { Logger *zap.Logger HeartBeatInterval time.Duration IdleTimeout time.Duration + // PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max + // capacity of ~100MB. 
+ PreparedCache proxycore.PreparedCache } type Proxy struct { @@ -59,8 +63,9 @@ type Proxy struct { listener *net.TCPListener cluster *proxycore.Cluster sessions sync.Map - sessMu *sync.Mutex + sessMu sync.Mutex schemaEventClients sync.Map + preparedCache proxycore.PreparedCache clientIdGen uint64 lb proxycore.LoadBalancer systemLocalValues map[string]message.Column @@ -90,12 +95,9 @@ func (p *Proxy) OnEvent(event interface{}) { func NewProxy(ctx context.Context, config Config) *Proxy { return &Proxy{ - ctx: ctx, - config: config, - logger: proxycore.GetOrCreateNopLogger(config.Logger), - sessions: sync.Map{}, - sessMu: &sync.Mutex{}, - schemaEventClients: sync.Map{}, + ctx: ctx, + config: config, + logger: proxycore.GetOrCreateNopLogger(config.Logger), } } @@ -110,6 +112,11 @@ func (p *Proxy) ListenAndServe(address string) error { func (p *Proxy) Listen(address string) error { var err error + p.preparedCache, err = getOrCreateDefaultPreparedCache(p.config.PreparedCache) + if err != nil { + return fmt.Errorf("unable to create prepared cache %w", err) + } + p.cluster, err = proxycore.ConnectCluster(p.ctx, proxycore.ClusterConfig{ Version: p.config.Version, Auth: p.config.Auth, @@ -141,9 +148,9 @@ func (p *Proxy) Listen(address string) error { NumConns: p.config.NumConns, Version: p.cluster.NegotiatedVersion, Auth: p.config.Auth, - Logger: p.config.Logger, HeartBeatInterval: p.config.HeartBeatInterval, IdleTimeout: p.config.IdleTimeout, + PreparedCache: p.preparedCache, }) if err != nil { @@ -176,6 +183,10 @@ func (p *Proxy) Serve() error { } } +func (p *Proxy) Shutdown() error { + return p.listener.Close() +} + func (p *Proxy) handle(conn *net.TCPConn) { if err := conn.SetKeepAlive(false); err != nil { p.logger.Warn("failed to disable keepalive on connection", zap.Error(err)) @@ -205,6 +216,7 @@ func (p *Proxy) maybeCreateSession(keyspace string) error { NumConns: p.config.NumConns, Version: p.cluster.NegotiatedVersion, Auth: p.config.Auth, + PreparedCache: p.preparedCache, Keyspace: keyspace, HeartBeatInterval: p.config.HeartBeatInterval, IdleTimeout: p.config.IdleTimeout, @@ -316,7 +328,7 @@ func (c *client) execute(raw *frame.RawFrame, idempotent bool) { qp: c.proxy.newQueryPlan(), raw: raw, } - req.execute() + req.Execute(true) } else { c.send(raw.Header, &message.ServerError{ErrorMessage: "Attempted to use invalid keyspace"}) } @@ -498,3 +510,34 @@ func (c *client) send(hdr *frame.Header, msg message.Message) { func (c *client) Closing(_ error) { c.proxy.schemaEventClients.Delete(c.id) } + +func getOrCreateDefaultPreparedCache(cache proxycore.PreparedCache) (proxycore.PreparedCache, error) { + if cache == nil { + return NewDefaultPreparedCache(1e8 / 256) // ~100MB with an average query size of 256 bytes + } + return cache, nil +} + +// NewDefaultPreparedCache creates a new default prepared cache capping the max item capacity to `size`. 
+func NewDefaultPreparedCache(size int) (proxycore.PreparedCache, error) { + cache, err := lru.New(size) + if err != nil { + return nil, err + } + return &defaultPreparedCache{cache}, nil +} + +type defaultPreparedCache struct { + cache *lru.Cache +} + +func (d defaultPreparedCache) Store(id string, entry *proxycore.PreparedEntry) { + d.cache.Add(id, entry) +} + +func (d defaultPreparedCache) Load(id string) (entry *proxycore.PreparedEntry, ok bool) { + if val, ok := d.cache.Get(id); ok { + return val.(*proxycore.PreparedEntry), true + } + return nil, false +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index a462053..7b5bd7a 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -19,6 +19,7 @@ import ( "errors" "net" "strconv" + "sync" "testing" "time" @@ -42,6 +43,7 @@ func TestProxy_ListenAndServe(t *testing.T) { const proxyContactPoint = "127.0.0.1:9042" cluster := proxycore.NewMockCluster(net.ParseIP("127.0.0.0"), clusterPort) + defer cluster.Shutdown() cluster.Handlers = proxycore.NewMockRequestHandlers(proxycore.MockRequestHandlers{ primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { @@ -82,22 +84,25 @@ func TestProxy_ListenAndServe(t *testing.T) { require.NoError(t, err) proxy := NewProxy(ctx, Config{ - Version: primitive.ProtocolVersion4, - Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort), - ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), - NumConns: 2, + Version: primitive.ProtocolVersion4, + Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort), + ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), + NumConns: 2, HeartBeatInterval: 30 * time.Second, IdleTimeout: 60 * time.Second, }) err = proxy.Listen(proxyContactPoint) + defer func(proxy *Proxy) { + _ = proxy.Shutdown() + }(proxy) require.NoError(t, err) go func() { _ = proxy.Serve() }() - cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint)) + cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint), proxycore.ClientConnConfig{}) require.NoError(t, err) version, err := cl.Handshake(ctx, primitive.ProtocolVersion4, nil) @@ -128,6 +133,103 @@ func TestProxy_ListenAndServe(t *testing.T) { assert.True(t, added) } +func TestProxy_Unprepared(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numNodes = 3 + + const clusterContactPoint = "127.0.0.1:8000" + const clusterPort = 8000 + + const proxyContactPoint = "127.0.0.1:9042" + const version = primitive.ProtocolVersion4 + + preparedId := []byte("abc") + + cluster := proxycore.NewMockCluster(net.ParseIP("127.0.0.0"), clusterPort) + defer cluster.Shutdown() + + var prepared sync.Map + + cluster.Handlers = proxycore.NewMockRequestHandlers(proxycore.MockRequestHandlers{ + primitive.OpCodePrepare: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + prepared.Store(cl.Local().IP, true) + return &message.PreparedResult{ + PreparedQueryId: preparedId, + } + }, + primitive.OpCodeExecute: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + if _, ok := prepared.Load(cl.Local().IP); ok { + return &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 0, + }, + Data: message.RowSet{}, + } + } else { + ex := frm.Body.Message.(*message.Execute) + assert.Equal(t, preparedId, ex.QueryId) + return &message.Unprepared{Id: ex.QueryId} + } + 
}, + }) + + for i := 1; i <= numNodes; i++ { + err := cluster.Add(ctx, i) + require.NoError(t, err) + } + + proxy := NewProxy(ctx, Config{ + Version: version, + Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort), + ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), + NumConns: 2, + HeartBeatInterval: 30 * time.Second, + IdleTimeout: 60 * time.Second, + }) + + err := proxy.Listen(proxyContactPoint) + defer func(proxy *Proxy) { + _ = proxy.Shutdown() + }(proxy) + require.NoError(t, err) + + go func() { + _ = proxy.Serve() + }() + + cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint), proxycore.ClientConnConfig{}) + require.NoError(t, err) + + negotiated, err := cl.Handshake(ctx, version, nil) + require.NoError(t, err) + assert.Equal(t, version, negotiated) + + // Only prepare on a single node + resp, err := cl.SendAndReceive(ctx, frame.NewFrame(version, 0, &message.Prepare{Query: "SELECT * FROM test.test"})) + require.NoError(t, err) + assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode) + _, ok := resp.Body.Message.(*message.PreparedResult) + assert.True(t, ok, "expected prepared result") + + for i := 0; i < numNodes; i++ { + resp, err = cl.SendAndReceive(ctx, frame.NewFrame(version, 0, &message.Execute{QueryId: preparedId})) + require.NoError(t, err) + assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode) + _, ok = resp.Body.Message.(*message.RowsResult) + assert.True(t, ok, "expected rows result") + } + + // Count the number of unique nodes that were prepared + count := 0 + prepared.Range(func(_, _ interface{}) bool { + count++ + return true + }) + assert.Equal(t, numNodes, count) +} + func testQueryHosts(ctx context.Context, cl *proxycore.ClientConn) (map[string]struct{}, error) { hosts := make(map[string]struct{}) for i := 0; i < 3; i++ { diff --git a/proxy/request.go b/proxy/request.go index 33d7fdd..0657a87 100644 --- a/proxy/request.go +++ b/proxy/request.go @@ -19,6 +19,7 @@ import ( "sync" "cql-proxy/proxycore" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "go.uber.org/zap" @@ -29,26 +30,29 @@ type request struct { session *proxycore.Session idempotent bool done bool + host *proxycore.Host stream int16 qp proxycore.QueryPlan raw *frame.RawFrame mu sync.Mutex } -func (r *request) execute() { +func (r *request) Execute(next bool) { r.mu.Lock() defer r.mu.Unlock() for !r.done { - host := r.qp.Next() - if host == nil { + if next { + r.host = r.qp.Next() + } + if r.host == nil { r.done = true r.send(&message.Unavailable{ErrorMessage: "No more hosts available (exhausted query plan)"}) } else { - err := r.session.Send(host, r) + err := r.session.Send(r.host, r) if err == nil { break } else { - r.client.proxy.logger.Debug("failed to send request to host", zap.Stringer("host", host), zap.Error(err)) + r.client.proxy.logger.Debug("failed to send request to host", zap.Stringer("host", r.host), zap.Error(err)) } } } @@ -73,7 +77,7 @@ func (r *request) Frame() interface{} { func (r *request) OnClose(_ error) { if r.idempotent { - r.execute() + r.Execute(true) } else { r.mu.Lock() if !r.done { diff --git a/proxycore/clientconn.go b/proxycore/clientconn.go index f2659ca..d7d19fa 100644 --- a/proxycore/clientconn.go +++ b/proxycore/clientconn.go @@ -16,6 +16,8 @@ package proxycore import ( "context" + "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -47,25 +49,31 @@ func (f EventHandlerFunc) 
OnEvent(frm *frame.Frame) { f(frm) } -type ClientConn struct { - conn *Conn - inflight int32 - pending *pendingRequests - eventHandler EventHandler - closing bool - closingMu *sync.RWMutex +type ClientConnConfig struct { + PreparedCache PreparedCache + Handler EventHandler + Logger *zap.Logger } -// ConnectClient creates a new connection to an endpoint within a downstream cluster using TLS if specified. -func ConnectClient(ctx context.Context, endpoint Endpoint) (*ClientConn, error) { - return ConnectClientWithEvents(ctx, endpoint, nil) +type ClientConn struct { + conn *Conn + inflight int32 + pending *pendingRequests + eventHandler EventHandler + preparedCache PreparedCache + logger *zap.Logger + closing bool + closingMu *sync.RWMutex } -func ConnectClientWithEvents(ctx context.Context, endpoint Endpoint, handler EventHandler) (*ClientConn, error) { +// ConnectClient creates a new connection to an endpoint within a downstream cluster using TLS if specified. +func ConnectClient(ctx context.Context, endpoint Endpoint, config ClientConnConfig) (*ClientConn, error) { c := &ClientConn{ - pending: newPendingRequests(MaxStreams), - eventHandler: handler, - closingMu: &sync.RWMutex{}, + pending: newPendingRequests(MaxStreams), + eventHandler: config.Handler, + closingMu: &sync.RWMutex{}, + preparedCache: config.PreparedCache, + logger: GetOrCreateNopLogger(config.Logger), } var err error c.conn, err = Connect(ctx, endpoint, c) @@ -257,12 +265,91 @@ func (c *ClientConn) Receive(reader io.Reader) error { return errors.New("invalid stream") } atomic.AddInt32(&c.inflight, -1) - request.OnResult(raw) + + handled := false + + // If we have a prepared cache attempt to recover from unprepared errors and cache previously seen prepared + // requests (so they can be used to prepare other nodes). + if c.preparedCache != nil { + switch raw.Header.OpCode { + case primitive.OpCodeError: + handled = c.maybePrepareAndExecute(request, raw) + case primitive.OpCodeResult: + c.maybeCachePrepared(request, raw) + } + } + + if !handled { + request.OnResult(raw) + } } return nil } +// maybePrepareAndExecute checks the response looking for unprepared errors and attempts to prepare them. +// If an unprepared error is encountered it attempts to prepare the query on the connection and re-execute the original +// request. +func (c *ClientConn) maybePrepareAndExecute(request Request, raw *frame.RawFrame) bool { + code, err := readInt(raw.Body) + if err != nil { + c.logger.Error("failed to read `code` in error response", zap.Error(err)) + return false + } + + if primitive.ErrorCode(code) == primitive.ErrorCodeUnprepared { + frm, err := codec.ConvertFromRawFrame(raw) + if err != nil { + c.logger.Error("failed to decode unprepared error response", zap.Error(err)) + return false + } + msg := frm.Body.Message.(*message.Unprepared) + id := hex.EncodeToString(msg.Id) + if prepare, ok := c.preparedCache.Load(id); ok { + err = c.Send(&prepareRequest{ + prepare: prepare.PreparedFrame, + origRequest: request, + }) + if err != nil { + c.logger.Error("failed to prepare query after receiving an unprepared error response", + zap.String("host", c.conn.RemoteAddr().String()), + zap.String("id", id), + zap.Error(err)) + return false + } else { + return true + } + } else { + c.logger.Warn("received unprepared error response, but existing prepared ID not in the cache", + zap.String("id", id)) + } + } + return false +} + +// maybeCachePrepared checks the response looking for prepared frames and caches the original prepare request. 
+// This is done so that the prepare request can be used to prepare other nodes that have not been prepared, but are +// attempting to execute a request that has been prepared on another node in the cluster. +func (c *ClientConn) maybeCachePrepared(request Request, raw *frame.RawFrame) { + kind, err := readInt(raw.Body) + if err != nil { + c.logger.Error("failed to read `kind` in result response", zap.Error(err)) + return + } + if primitive.ResultType(kind) == primitive.ResultTypePrepared { + frm, err := codec.ConvertFromRawFrame(raw) + if err != nil { + c.logger.Error("failed to decode prepared result response", zap.Error(err)) + return + } + msg := frm.Body.Message.(*message.PreparedResult) + c.preparedCache.Store(hex.EncodeToString(msg.PreparedQueryId), + &PreparedEntry{ + request.Frame().(*frame.RawFrame), // Store frame so we can re-prepare + }) + } +} + func (c *ClientConn) Closing(err error) { c.closingMu.Lock() c.closing = true @@ -395,6 +482,10 @@ type internalRequest struct { res chan *frame.RawFrame } +func (i *internalRequest) Execute(_ bool) { + panic("not implemented") +} + func (i *internalRequest) Frame() interface{} { return i.frame } @@ -414,3 +505,35 @@ func (i *internalRequest) OnResult(raw *frame.RawFrame) { panic("attempted to set result multiple times") } } + +type prepareRequest struct { + prepare *frame.RawFrame + origRequest Request +} + +func (r *prepareRequest) Execute(_ bool) { + panic("not implemented") +} + +func (r *prepareRequest) Frame() interface{} { + return r.prepare +} + +func (r *prepareRequest) OnClose(err error) { + r.origRequest.OnClose(err) +} + +func (r *prepareRequest) OnResult(raw *frame.RawFrame) { + next := false // If there's no error then we re-try on the original host + if raw.Header.OpCode == primitive.OpCodeError { + next = true // Try the next node + } + r.origRequest.Execute(next) +} + +func readInt(bytes []byte) (int32, error) { + if len(bytes) < 4 { + return 0, errors.New("[int] expects at least 4 bytes") + } + return int32(binary.BigEndian.Uint32(bytes)), nil +} diff --git a/proxycore/clientconn_test.go b/proxycore/clientconn_test.go index 456f2d6..9bfc570 100644 --- a/proxycore/clientconn_test.go +++ b/proxycore/clientconn_test.go @@ -17,6 +17,7 @@ package proxycore import ( "bytes" "context" + "encoding/hex" "net" "sync" "testing" @@ -27,6 +28,7 @@ import ( "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -42,7 +44,7 @@ func TestClientConn_Handshake(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) version, err := cl.Handshake(ctx, primitive.ProtocolVersion4, nil) @@ -65,7 +67,7 @@ func TestClientConn_HandshakeNegotiateProtocolVersion(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) version, err := cl.Handshake(ctx, starting, nil) @@ -89,7 +91,7 @@ func TestClientConn_HandshakePasswordAuth(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, 
NewPasswordAuth(username, password)) @@ -112,9 +114,11 @@ func TestConnectClientWithEvents(t *testing.T) { require.NoError(t, err) events := make(chan *frame.Frame) - cl, err := ConnectClientWithEvents(ctx, &defaultEndpoint{"127.0.0.1:9042"}, EventHandlerFunc(func(frm *frame.Frame) { - events <- frm - })) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{ + Handler: EventHandlerFunc(func(frm *frame.Frame) { + events <- frm + }), + }) require.NoError(t, err) wait := func() *frame.Frame { @@ -158,7 +162,7 @@ func TestClientConn_HandshakePasswordInvalidAuth(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, NewPasswordAuth("invalid", "invalid")) @@ -182,7 +186,7 @@ func TestClientConn_HandshakeAuthRequireButNotProvided(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, starting, nil) @@ -206,7 +210,7 @@ func TestClientConn_Query(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -255,7 +259,7 @@ func TestClientConn_SetKeyspace(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -309,7 +313,7 @@ func TestClientConn_Inflight(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -330,10 +334,199 @@ func TestClientConn_Inflight(t *testing.T) { assert.Equal(t, int32(0), cl.Inflight()) // Should be 0 after they complete } +func TestClientConn_Unprepared(t *testing.T) { + const ( + Unprepared int32 = iota + UnpreparedError + Prepared + Executed + ) + + preparedId := []byte("abc") + + state := atomic.NewInt32(Unprepared) + + server := &MockServer{ + Handlers: NewMockRequestHandlers(MockRequestHandlers{ + primitive.OpCodePrepare: func(cl *MockClient, frm *frame.Frame) message.Message { + require.True(t, state.CAS(UnpreparedError, Prepared), "expected the query to be prepared as the result of an unprepared error") + return &message.PreparedResult{ + PreparedQueryId: preparedId, + } + }, + primitive.OpCodeExecute: func(cl *MockClient, frm *frame.Frame) message.Message { + if state.CAS(Unprepared, UnpreparedError) { + ex := frm.Body.Message.(*message.Execute) + require.Equal(t, preparedId, ex.QueryId) + return &message.Unprepared{Id: preparedId} + } else if state.CAS(Prepared, Executed) { + return &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 0, + }, + Data: message.RowSet{}, + } + } else { + return &message.ServerError{ErrorMessage: "expected the query to be either unprepared or prepared"} + } + }, + }), + } + + const supported = primitive.ProtocolVersion4 + + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + + err := server.Serve(ctx, supported, MockHost{ + IP: "127.0.0.1", + Port: 9042, + HostID: mockHostID, + }, nil) + require.NoError(t, err) + + // Pre-populate the prepared cache as if the query was prepared, but on another node + prepareFrame, err := codec.ConvertToRawFrame(frame.NewFrame(supported, 0, &message.Prepare{Query: "SELECT * FROM test.test"})) + require.NoError(t, err) + + var preparedCache testPrepareCache + preparedCache.Store(hex.EncodeToString(preparedId), &PreparedEntry{prepareFrame}) + + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{PreparedCache: &preparedCache}) + defer func(cl *ClientConn) { + _ = cl.Close() + }(cl) + require.NoError(t, err) + + _, err = cl.Handshake(ctx, supported, nil) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) // + + err = cl.Send(&testPrepareRequest{ + t: t, + wg: &wg, + cl: cl, + version: supported, + preparedId: preparedId, + }) + require.NoError(t, err) + + wg.Wait() +} + +func TestClientConn_UnpreparedNotCached(t *testing.T) { + preparedId := []byte("abc") + + server := &MockServer{ + Handlers: NewMockRequestHandlers(MockRequestHandlers{ + primitive.OpCodePrepare: func(cl *MockClient, frm *frame.Frame) message.Message { + require.Fail(t, "prepare was never cached so this shouldn't happen") + return &message.PreparedResult{ + PreparedQueryId: preparedId, + } + }, + primitive.OpCodeExecute: func(cl *MockClient, frm *frame.Frame) message.Message { + ex := frm.Body.Message.(*message.Execute) + require.Equal(t, preparedId, ex.QueryId) + return &message.Unprepared{Id: ex.QueryId} + }, + }), + } + + const supported = primitive.ProtocolVersion4 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := server.Serve(ctx, supported, MockHost{ + IP: "127.0.0.1", + Port: 9042, + HostID: mockHostID, + }, nil) + require.NoError(t, err) + + logger, _ := zap.NewDevelopment() + + var preparedCache testPrepareCache + + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, + ClientConnConfig{ + PreparedCache: &preparedCache, // Empty cache + Logger: logger, + }, + ) + defer func(cl *ClientConn) { + _ = cl.Close() + }(cl) + require.NoError(t, err) + + _, err = cl.Handshake(ctx, supported, nil) + require.NoError(t, err) + + resp, err := cl.SendAndReceive(ctx, frame.NewFrame(supported, 0, &message.Execute{QueryId: preparedId})) + require.NoError(t, err) + + assert.Equal(t, primitive.OpCodeError, resp.Header.OpCode) + + _, ok := resp.Body.Message.(*message.Unprepared) + assert.True(t, ok, "expecting an unprepared response") +} + +type testPrepareCache struct { + cache sync.Map +} + +func (t *testPrepareCache) Store(id string, entry *PreparedEntry) { + t.cache.Store(id, entry) +} + +func (t *testPrepareCache) Load(id string) (entry *PreparedEntry, ok bool) { + if val, ok := t.cache.Load(id); ok { + return val.(*PreparedEntry), true + } + return nil, false +} + +type testPrepareRequest struct { + t *testing.T + wg *sync.WaitGroup + cl *ClientConn + version primitive.ProtocolVersion + preparedId []byte +} + +func (t *testPrepareRequest) Frame() interface{} { + return frame.NewFrame(t.version, 0, &message.Execute{QueryId: t.preparedId}) +} + +func (t *testPrepareRequest) Execute(next bool) { + err := t.cl.Send(t) + require.NoError(t.t, err) +} + +func (t *testPrepareRequest) OnClose(_ error) { + panic("not implemented") +} + +func (t *testPrepareRequest) OnResult(raw *frame.RawFrame) { + assert.Equal(t.t, 
primitive.OpCodeResult, raw.Header.OpCode) + frm, err := codec.ConvertFromRawFrame(raw) + require.NoError(t.t, err) + _, ok := frm.Body.Message.(*message.RowsResult) + assert.True(t.t, ok) + t.wg.Done() +} + type testInflightRequest struct { wg *sync.WaitGroup } +func (t testInflightRequest) Execute(_ bool) { + panic("not implemented") +} + func (t testInflightRequest) Frame() interface{} { return frame.NewFrame(primitive.ProtocolVersion4, -1, &message.Query{ Query: "SELECT * FROM system.local", @@ -341,6 +534,7 @@ func (t testInflightRequest) Frame() interface{} { } func (t testInflightRequest) OnClose(_ error) { + panic("not implemented") } func (t testInflightRequest) OnResult(_ *frame.RawFrame) { @@ -403,7 +597,7 @@ func TestClientConn_Heartbeats(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -441,7 +635,7 @@ func TestClientConn_HeartbeatsError(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -483,7 +677,7 @@ func TestClientConn_HeartbeatsTimeout(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) @@ -525,7 +719,7 @@ func TestClientConn_HeartbeatsUnexpectedMessage(t *testing.T) { }, nil) require.NoError(t, err) - cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}) + cl, err := ConnectClient(ctx, &defaultEndpoint{"127.0.0.1:9042"}, ClientConnConfig{}) require.NoError(t, err) _, err = cl.Handshake(ctx, supported, nil) diff --git a/proxycore/cluster.go b/proxycore/cluster.go index 8da95a1..fbd928d 100644 --- a/proxycore/cluster.go +++ b/proxycore/cluster.go @@ -157,7 +157,7 @@ func (c *Cluster) connect(ctx context.Context, endpoint Endpoint, initial bool) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - conn, err := ConnectClientWithEvents(ctx, endpoint, c) + conn, err := ConnectClient(ctx, endpoint, ClientConnConfig{Handler: c, Logger: c.logger}) if err != nil { return err } diff --git a/proxycore/connpool.go b/proxycore/connpool.go index 6e6cd95..a599b74 100644 --- a/proxycore/connpool.go +++ b/proxycore/connpool.go @@ -32,13 +32,14 @@ type connPoolConfig struct { } type connPool struct { - ctx context.Context - config connPoolConfig - logger *zap.Logger - cancel context.CancelFunc - remaining int32 - conns []*ClientConn - connsMu *sync.RWMutex + ctx context.Context + config connPoolConfig + logger *zap.Logger + preparedCache PreparedCache + cancel context.CancelFunc + remaining int32 + conns []*ClientConn + connsMu *sync.RWMutex } // connectPool establishes a pool of connections to a given endpoint within a downstream cluster. 
These connection pools will @@ -47,13 +48,14 @@ func connectPool(ctx context.Context, config connPoolConfig) (*connPool, error) ctx, cancel := context.WithCancel(ctx) pool := &connPool{ - ctx: ctx, - config: config, - logger: GetOrCreateNopLogger(config.Logger), - cancel: cancel, - remaining: int32(config.NumConns), - conns: make([]*ClientConn, config.NumConns), - connsMu: &sync.RWMutex{}, + ctx: ctx, + config: config, + logger: GetOrCreateNopLogger(config.Logger), + preparedCache: config.PreparedCache, + cancel: cancel, + remaining: int32(config.NumConns), + conns: make([]*ClientConn, config.NumConns), + connsMu: &sync.RWMutex{}, } errs := make([]error, config.NumConns) @@ -133,7 +135,9 @@ func (p *connPool) connect() (conn *ClientConn, err error) { timeout := getOrUseDefault(p.config.ConnectTimeout, DefaultConnectTimeout) ctx, cancel := context.WithTimeout(p.ctx, timeout) defer cancel() - conn, err = ConnectClient(ctx, p.config.Endpoint) + conn, err = ConnectClient(ctx, p.config.Endpoint, ClientConnConfig{ + PreparedCache: p.preparedCache, + Logger: p.logger}) if err != nil { return nil, err } diff --git a/proxycore/mockcluster.go b/proxycore/mockcluster.go index aaaec05..49a8715 100644 --- a/proxycore/mockcluster.go +++ b/proxycore/mockcluster.go @@ -28,6 +28,7 @@ import ( "time" "cql-proxy/parser" + "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" @@ -235,6 +236,7 @@ func (c MockClient) Closing(_ error) { } type MockServer struct { + wg sync.WaitGroup cancel context.CancelFunc clients sync.Map clientIdGen uint64 @@ -282,6 +284,11 @@ func (s *MockServer) Remove(host MockHost) { }) } +func (s *MockServer) Shutdown() { + s.cancel() + s.wg.Wait() +} + func (s *MockServer) Event(evt message.Event) { s.clients.Range(func(_, value interface{}) bool { cl := value.(*MockClient) @@ -337,10 +344,13 @@ func (s *MockServer) Serve(ctx context.Context, maxVersion primitive.ProtocolVer copy(s.peers, peers) s.peers = removeHost(s.peers, local) + s.wg.Add(1) + go func() { for { c, err := listener.Accept() if err != nil { + s.wg.Done() break } id := atomic.AddUint64(&s.clientIdGen, 1) @@ -364,7 +374,6 @@ func (s *MockServer) Serve(ctx context.Context, maxVersion primitive.ProtocolVer }(cl) s.clients.Store(id, cl) cl.conn.Start() - } }() @@ -470,6 +479,14 @@ func (c *MockCluster) maybeStop(host MockHost) { func (c *MockCluster) Stop(n int) { c.maybeStop(c.generate(n)) } + +func (c *MockCluster) Shutdown() { + for _, server := range c.servers { + server.Shutdown() + } +} + + func makeSystemLocalValues(version primitive.ProtocolVersion, address string, hostID, schemaVersion *primitive.UUID) map[string]message.Column { ip := net.ParseIP(address) values := makeSystemValues(version, ip, hostID, schemaVersion) diff --git a/proxycore/requests.go b/proxycore/requests.go index 0a0e7db..9c29683 100644 --- a/proxycore/requests.go +++ b/proxycore/requests.go @@ -15,13 +15,31 @@ package proxycore import ( - "github.com/datastax/go-cassandra-native-protocol/frame" "sync" + + "github.com/datastax/go-cassandra-native-protocol/frame" ) +// Request represents the data frame and lifecycle of a CQL native protocol request. type Request interface { + // Frame returns the frame to be executed as part of the request. + // This must be idempotent. Frame() interface{} + + // Execute is called when a request need to be retried. + // This is currently only called for executing prepared requests (i.e. 
`EXECUTE` request frames). If `EXECUTE` + request frames are not expected then the implementation should `panic()`. + // + // If `next` is false then the request must be retried on the current node; otherwise, it should be retried on + // another node, which is usually the next node in a query plan. + Execute(next bool) + + // OnClose is called when the underlying connection is closed. + // No assumptions should be made about whether the request has been successfully sent; it is possible that + // the request has been fully sent and no response was received before the connection closed. OnClose(err error) + + // OnResult is called when a response frame has been sent back from the connection. OnResult(raw *frame.RawFrame) } diff --git a/proxycore/requests_test.go b/proxycore/requests_test.go index 9cc3619..09c5af3 100644 --- a/proxycore/requests_test.go +++ b/proxycore/requests_test.go @@ -15,10 +15,11 @@ package proxycore import ( - "github.com/datastax/go-cassandra-native-protocol/frame" - "github.com/stretchr/testify/assert" "io" "testing" + + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/stretchr/testify/assert" ) func TestPendingRequests(t *testing.T) { @@ -57,8 +58,12 @@ type testPendingRequest struct { errs *[]error } +func (t testPendingRequest) Execute(_ bool) { + panic("not implemented") +} + func (t testPendingRequest) Frame() interface{} { - panic("implement me") + panic("not implemented") } func (t *testPendingRequest) OnClose(err error) { @@ -66,5 +71,5 @@ func (t *testPendingRequest) OnClose(err error) { } func (t testPendingRequest) OnResult(_ *frame.RawFrame) { - panic("implement me") + panic("not implemented") } diff --git a/proxycore/session.go b/proxycore/session.go index 03d2fca..7d601c2 100644 --- a/proxycore/session.go +++ b/proxycore/session.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/primitive" "go.uber.org/zap" ) @@ -28,6 +29,19 @@ var ( NoConnForHost = errors.New("no connection available for host") ) +// PreparedEntry is an entry in the prepared cache. +type PreparedEntry struct { + PreparedFrame *frame.RawFrame +} + +// PreparedCache is a thread-safe cache for storing prepared queries. +type PreparedCache interface { + // Store adds an entry to the cache. + Store(id string, entry *PreparedEntry) + // Load retrieves an entry from the cache. `ok` is true if the entry is present; otherwise it's false. + Load(id string) (entry *PreparedEntry, ok bool) +} + type SessionConfig struct { ReconnectPolicy ReconnectPolicy NumConns int @@ -35,7 +49,9 @@ type SessionConfig struct { Version primitive.ProtocolVersion Auth Authenticator Logger *zap.Logger - ConnectTimeout time.Duration + // PreparedCache is a global cache shared across sessions for storing previously prepared queries. + PreparedCache PreparedCache + ConnectTimeout time.Duration HeartBeatInterval time.Duration IdleTimeout time.Duration } diff --git a/proxycore/session_test.go b/proxycore/session_test.go index 6a43e9e..84f4be0 100644 --- a/proxycore/session_test.go +++ b/proxycore/session_test.go @@ -119,6 +119,10 @@ type testSessionRequest struct { wg *sync.WaitGroup } +func (r testSessionRequest) Execute(next bool) { + panic("not implemented") +} + func (r testSessionRequest) Frame() interface{} { return frame.NewFrame(primitive.ProtocolVersion4, -1, &message.Query{ Query: "SELECT * FROM system.local",
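
Usage sketch (illustrative; not part of the patch above). The diff adds a PreparedCache hook to both proxy.Config and proxycore.SessionConfig, falling back to the LRU cache from NewDefaultPreparedCache when the field is left nil. The sketch below shows how a caller might supply a custom cache; the sync.Map-backed implementation mirrors the testPrepareCache helper in clientconn_test.go, and the import path cql-proxy/proxy, the addresses, and the main wiring are assumptions made only to keep the example self-contained.

    package main

    import (
        "context"
        "log"
        "sync"
        "time"

        "cql-proxy/proxy"      // assumed package path for the proxy package in this module
        "cql-proxy/proxycore"

        "github.com/datastax/go-cassandra-native-protocol/primitive"
    )

    // mapPreparedCache is a hypothetical, unbounded PreparedCache backed by a sync.Map.
    // It mirrors the shape of testPrepareCache in clientconn_test.go.
    type mapPreparedCache struct {
        cache sync.Map
    }

    func (m *mapPreparedCache) Store(id string, entry *proxycore.PreparedEntry) {
        m.cache.Store(id, entry)
    }

    func (m *mapPreparedCache) Load(id string) (*proxycore.PreparedEntry, bool) {
        if val, ok := m.cache.Load(id); ok {
            return val.(*proxycore.PreparedEntry), true
        }
        return nil, false
    }

    func main() {
        ctx := context.Background()

        p := proxy.NewProxy(ctx, proxy.Config{
            Version:           primitive.ProtocolVersion4,
            Resolver:          proxycore.NewResolverWithDefaultPort([]string{"127.0.0.1"}, 9042),
            ReconnectPolicy:   proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second),
            NumConns:          2,
            HeartBeatInterval: 30 * time.Second,
            IdleTimeout:       60 * time.Second,
            PreparedCache:     &mapPreparedCache{}, // leave nil to get the default ~100MB LRU cache
        })

        if err := p.Listen("127.0.0.1:9043"); err != nil {
            log.Fatal(err)
        }
        defer func() { _ = p.Shutdown() }()

        if err := p.Serve(); err != nil {
            log.Println(err)
        }
    }

Because the same cache instance is passed down to every session and connection pool (see SessionConfig.PreparedCache and connPool above), a prepare observed on one node can later be replayed to any other node that answers an EXECUTE with an Unprepared error.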