From 6370ae03be8b4fbfa90b90ac9aa9133b1ce4b8f8 Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Wed, 30 Oct 2024 18:01:09 +0200 Subject: [PATCH] 1. Updated the way how the driver constructs stmt cache keys. The current code base uses initial keyspace provided by the user to construct the keys. Since proto v5 we also should account for keyspace bounding for a specific query, so the driver should use the bounded keyspace instead of the initial to construct the key. 2. Changed the way how routing key cache keys are constructed to account the keyspace overriding as well. --- cassandra_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++-- conn.go | 30 ++++++++----- session.go | 21 ++++++---- 3 files changed, 134 insertions(+), 22 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 773bb288c..dfc96feb8 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) { defer session.Close() conn := getRandomConn(t, session) - info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil) + info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil, conn.currentKeyspace) if err != nil { t.Fatalf("Failed to execute query for preparing statement: %v", err) @@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) { } // verify the cache is working - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -3606,3 +3606,100 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID) require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response) } + +func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)" + + // Inserting data via Batch to ensure that batches + // properly accounts for keyspace overriding + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + err = session.ExecuteBatch(b1) + require.NoError(t, err) + + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_stmt_cache") + b2.Query(insertQuery, 2) + err = session.ExecuteBatch(b2) + require.NoError(t, err) + + var scannedID int + + const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks" + + // By default in our test suite session uses gocql_test ks + err = session.Query(selectStmt).Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 1, scannedID) + + scannedID = 0 + err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 2, scannedID) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec() +} + +func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?" + + q1 := session.Query(selectStmt, 1) + _, err = q1.GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + + q2 := session.Query(selectStmt, 1) + _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() +} diff --git a/conn.go b/conn.go index 7c2728ba0..7a50a732b 100644 --- a/conn.go +++ b/conn.go @@ -1410,8 +1410,8 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) +func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), @@ -1486,10 +1486,6 @@ func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tra } } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { - return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace) -} - func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name @@ -1531,6 +1527,13 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.nowInSeconds = qry.nowInSecondsValue } + // If a keyspace for the qry is overriden, + // then we should use it to create stmt cache key + usedKeyspace := c.currentKeyspace + if qry.keyspace != "" { + usedKeyspace = qry.keyspace + } + var ( frame frameBuilder info *preparedStatment @@ -1539,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace) + info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1616,7 +1619,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // If a RESULT/Rows message reports // changed resultset metadata with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ @@ -1685,7 +1688,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: @@ -1767,6 +1770,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { req.nowInSeconds = batch.nowInSeconds } + usedKeyspace := c.currentKeyspace + if batch.keyspace != "" { + usedKeyspace = batch.keyspace + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { @@ -1774,7 +1782,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace) + info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1836,7 +1844,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { - key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) + key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) diff --git a/session.go b/session.go index 2175c28e2..08326f6de 100644 --- a/session.go +++ b/session.go @@ -591,11 +591,18 @@ func (s *Session) getConn() *Conn { return nil } -// returns routing key indexes and type info -func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) { +// Returns routing key indexes and type info. +// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace +func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) { + if keyspace == "" { + keyspace = s.cfg.Keyspace + } + s.routingKeyInfoCache.mu.Lock() - entry, cached := s.routingKeyInfoCache.lru.Get(stmt) + // Using here keyspace + stmt as a cache key because + // the query keyspace could be overridden via SetKeyspace + entry, cached := s.routingKeyInfoCache.lru.Get(keyspace + stmt) if cached { // done accessing the cache s.routingKeyInfoCache.mu.Unlock() @@ -635,7 +642,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil) + info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace) if inflight.err != nil { // don't cache this error s.routingKeyInfoCache.Remove(stmt) @@ -651,7 +658,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } table := info.request.table - keyspace := info.request.keyspace + keyspace = info.request.keyspace if len(info.request.pkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns @@ -1177,7 +1184,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) { } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) if err != nil { return nil, err } @@ -2009,7 +2016,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) { return nil, nil } // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) if err != nil { return nil, err }