Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prepared queries #28

Merged
merged 14 commits into from
Sep 3, 2021
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/alecthomas/kong v0.2.17
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec
github.com/datastax/go-cassandra-native-protocol v0.0.0-20210604174339-4311e5d5654d
github.com/hashicorp/golang-lru v0.5.4
github.com/stretchr/testify v1.7.0
go.uber.org/atomic v1.8.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
Expand Down
16 changes: 3 additions & 13 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
github.com/alecthomas/kong v0.2.17 h1:URDISCI96MIgcIlQyoCAlhOmrSw6pZScBNkctg8r0W0=
github.com/alecthomas/kong v0.2.17/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ=

github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A=
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand All @@ -11,8 +10,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw=
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E=

github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
Expand All @@ -21,14 +21,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo=


github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.8.0 h1:CUhrE4N1rqSE6FM9ecihEjRkLQu8cDfgDyoOs83mEY4=
go.uber.org/atomic v1.8.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
Expand All @@ -38,23 +35,16 @@ go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95a
go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=

golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=

golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 3 additions & 2 deletions proxy/codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"

"github.com/datastax/go-cassandra-native-protocol/frame"
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/datastax/go-cassandra-native-protocol/primitive"
"io"
)

var codec = frame.NewRawCodec(&partialQueryCodec{}, &partialExecuteCodec{})
Expand Down Expand Up @@ -110,4 +111,4 @@ func (c *partialExecuteCodec) Decode(source io.Reader, _ primitive.ProtocolVersi

func (c *partialExecuteCodec) GetOpCode() primitive.OpCode {
return primitive.OpCodeExecute
}
}
61 changes: 52 additions & 9 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/datastax/go-cassandra-native-protocol/frame"
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/datastax/go-cassandra-native-protocol/primitive"
lru "github.com/hashicorp/golang-lru"
"go.uber.org/zap"
)

Expand All @@ -50,6 +51,9 @@ type Config struct {
Logger *zap.Logger
HeartBeatInterval time.Duration
IdleTimeout time.Duration
// PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max
// capacity of ~100MB.
PreparedCache proxycore.PreparedCache
}

type Proxy struct {
Expand All @@ -59,8 +63,9 @@ type Proxy struct {
listener *net.TCPListener
cluster *proxycore.Cluster
sessions sync.Map
sessMu *sync.Mutex
sessMu sync.Mutex
schemaEventClients sync.Map
preparedCache proxycore.PreparedCache
clientIdGen uint64
lb proxycore.LoadBalancer
systemLocalValues map[string]message.Column
Expand Down Expand Up @@ -90,12 +95,9 @@ func (p *Proxy) OnEvent(event interface{}) {

func NewProxy(ctx context.Context, config Config) *Proxy {
return &Proxy{
ctx: ctx,
config: config,
logger: proxycore.GetOrCreateNopLogger(config.Logger),
sessions: sync.Map{},
sessMu: &sync.Mutex{},
schemaEventClients: sync.Map{},
ctx: ctx,
config: config,
logger: proxycore.GetOrCreateNopLogger(config.Logger),
}
}

Expand All @@ -110,6 +112,11 @@ func (p *Proxy) ListenAndServe(address string) error {
func (p *Proxy) Listen(address string) error {
var err error

p.preparedCache, err = getOrCreateDefaultPreparedCache(p.config.PreparedCache)
if err != nil {
return fmt.Errorf("unable to create prepared cache %w", err)
}

p.cluster, err = proxycore.ConnectCluster(p.ctx, proxycore.ClusterConfig{
Version: p.config.Version,
Auth: p.config.Auth,
Expand Down Expand Up @@ -141,9 +148,9 @@ func (p *Proxy) Listen(address string) error {
NumConns: p.config.NumConns,
Version: p.cluster.NegotiatedVersion,
Auth: p.config.Auth,
Logger: p.config.Logger,
HeartBeatInterval: p.config.HeartBeatInterval,
IdleTimeout: p.config.IdleTimeout,
PreparedCache: p.preparedCache,
})

if err != nil {
Expand Down Expand Up @@ -176,6 +183,10 @@ func (p *Proxy) Serve() error {
}
}

func (p *Proxy) Shutdown() error {
return p.listener.Close()
}

func (p *Proxy) handle(conn *net.TCPConn) {
if err := conn.SetKeepAlive(false); err != nil {
p.logger.Warn("failed to disable keepalive on connection", zap.Error(err))
Expand Down Expand Up @@ -205,6 +216,7 @@ func (p *Proxy) maybeCreateSession(keyspace string) error {
NumConns: p.config.NumConns,
Version: p.cluster.NegotiatedVersion,
Auth: p.config.Auth,
PreparedCache: p.preparedCache,
Keyspace: keyspace,
HeartBeatInterval: p.config.HeartBeatInterval,
IdleTimeout: p.config.IdleTimeout,
Expand Down Expand Up @@ -316,7 +328,7 @@ func (c *client) execute(raw *frame.RawFrame, idempotent bool) {
qp: c.proxy.newQueryPlan(),
raw: raw,
}
req.execute()
req.Execute(true)
} else {
c.send(raw.Header, &message.ServerError{ErrorMessage: "Attempted to use invalid keyspace"})
}
Expand Down Expand Up @@ -498,3 +510,34 @@ func (c *client) send(hdr *frame.Header, msg message.Message) {
func (c *client) Closing(_ error) {
c.proxy.schemaEventClients.Delete(c.id)
}

func getOrCreateDefaultPreparedCache(cache proxycore.PreparedCache) (proxycore.PreparedCache, error) {
if cache == nil {
return NewDefaultPreparedCache(1e8 / 256) // ~100MB with an average query size of 256 bytes
}
return cache, nil
}

// NewDefaultPreparedCache creates a new default prepared cache capping the max item capacity to `size`.
func NewDefaultPreparedCache(size int) (proxycore.PreparedCache, error) {
cache, err := lru.New(size)
if err != nil {
return nil, err
}
return &defaultPreparedCache{cache}, nil
}

type defaultPreparedCache struct {
cache *lru.Cache
}

func (d defaultPreparedCache) Store(id string, entry *proxycore.PreparedEntry) {
d.cache.Add(id, entry)
}

func (d defaultPreparedCache) Load(id string) (entry *proxycore.PreparedEntry, ok bool) {
if val, ok := d.cache.Get(id); ok {
return val.(*proxycore.PreparedEntry), true
}
return nil, false
}
112 changes: 107 additions & 5 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"net"
"strconv"
"sync"
"testing"
"time"

Expand All @@ -42,6 +43,7 @@ func TestProxy_ListenAndServe(t *testing.T) {
const proxyContactPoint = "127.0.0.1:9042"

cluster := proxycore.NewMockCluster(net.ParseIP("127.0.0.0"), clusterPort)
defer cluster.Shutdown()

cluster.Handlers = proxycore.NewMockRequestHandlers(proxycore.MockRequestHandlers{
primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message {
Expand Down Expand Up @@ -82,22 +84,25 @@ func TestProxy_ListenAndServe(t *testing.T) {
require.NoError(t, err)

proxy := NewProxy(ctx, Config{
Version: primitive.ProtocolVersion4,
Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort),
ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second),
NumConns: 2,
Version: primitive.ProtocolVersion4,
Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort),
ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second),
NumConns: 2,
HeartBeatInterval: 30 * time.Second,
IdleTimeout: 60 * time.Second,
})

err = proxy.Listen(proxyContactPoint)
defer func(proxy *Proxy) {
_ = proxy.Shutdown()
}(proxy)
require.NoError(t, err)

go func() {
_ = proxy.Serve()
}()

cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint))
cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint), proxycore.ClientConnConfig{})
require.NoError(t, err)

version, err := cl.Handshake(ctx, primitive.ProtocolVersion4, nil)
Expand Down Expand Up @@ -128,6 +133,103 @@ func TestProxy_ListenAndServe(t *testing.T) {
assert.True(t, added)
}

func TestProxy_Unprepared(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

const numNodes = 3

const clusterContactPoint = "127.0.0.1:8000"
const clusterPort = 8000

const proxyContactPoint = "127.0.0.1:9042"
const version = primitive.ProtocolVersion4

preparedId := []byte("abc")

cluster := proxycore.NewMockCluster(net.ParseIP("127.0.0.0"), clusterPort)
defer cluster.Shutdown()

var prepared sync.Map

cluster.Handlers = proxycore.NewMockRequestHandlers(proxycore.MockRequestHandlers{
primitive.OpCodePrepare: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message {
prepared.Store(cl.Local().IP, true)
return &message.PreparedResult{
PreparedQueryId: preparedId,
}
},
primitive.OpCodeExecute: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message {
if _, ok := prepared.Load(cl.Local().IP); ok {
return &message.RowsResult{
Metadata: &message.RowsMetadata{
ColumnCount: 0,
},
Data: message.RowSet{},
}
} else {
ex := frm.Body.Message.(*message.Execute)
assert.Equal(t, preparedId, ex.QueryId)
return &message.Unprepared{Id: ex.QueryId}
}
},
})

for i := 1; i <= numNodes; i++ {
err := cluster.Add(ctx, i)
require.NoError(t, err)
}

proxy := NewProxy(ctx, Config{
Version: version,
Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterContactPoint}, clusterPort),
ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second),
NumConns: 2,
HeartBeatInterval: 30 * time.Second,
IdleTimeout: 60 * time.Second,
})

err := proxy.Listen(proxyContactPoint)
defer func(proxy *Proxy) {
_ = proxy.Shutdown()
}(proxy)
require.NoError(t, err)

go func() {
_ = proxy.Serve()
}()

cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyContactPoint), proxycore.ClientConnConfig{})
require.NoError(t, err)

negotiated, err := cl.Handshake(ctx, version, nil)
require.NoError(t, err)
assert.Equal(t, version, negotiated)

// Only prepare on a single node
resp, err := cl.SendAndReceive(ctx, frame.NewFrame(version, 0, &message.Prepare{Query: "SELECT * FROM test.test"}))
require.NoError(t, err)
assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode)
_, ok := resp.Body.Message.(*message.PreparedResult)
assert.True(t, ok, "expected prepared result")

for i := 0; i < numNodes; i++ {
resp, err = cl.SendAndReceive(ctx, frame.NewFrame(version, 0, &message.Execute{QueryId: preparedId}))
require.NoError(t, err)
assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode)
_, ok = resp.Body.Message.(*message.RowsResult)
assert.True(t, ok, "expected rows result")
}

// Count the number of unique nodes that were prepared
count := 0
prepared.Range(func(_, _ interface{}) bool {
count++
return true
})
assert.Equal(t, numNodes, count)
}

func testQueryHosts(ctx context.Context, cl *proxycore.ClientConn) (map[string]struct{}, error) {
hosts := make(map[string]struct{})
for i := 0; i < 3; i++ {
Expand Down
Loading