diff --git a/cassandra_test.go b/cassandra_test.go index 773bb288c..d0e919746 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) { defer session.Close() conn := getRandomConn(t, session) - info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil) + info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil, conn.currentKeyspace) if err != nil { t.Fatalf("Failed to execute query for preparing statement: %v", err) @@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) { } // verify the cache is working - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -3606,3 +3606,135 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID) require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response) } + +func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)" + + // Inserting data via Batch to ensure that batches + // properly accounts for keyspace overriding + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + err = session.ExecuteBatch(b1) + require.NoError(t, err) + + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_stmt_cache") + b2.Query(insertQuery, 2) + err = session.ExecuteBatch(b2) + require.NoError(t, err) + + var scannedID int + + const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks" + + // By default in our test suite session uses gocql_test ks + err = session.Query(selectStmt).Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 1, scannedID) + + scannedID = 0 + err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 2, scannedID) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec() +} + +func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + getRoutingKeyInfo := func(key string) *routingKeyInfo { + t.Helper() + session.routingKeyInfoCache.mu.Lock() + value, _ := session.routingKeyInfoCache.lru.Get(key) + session.routingKeyInfoCache.mu.Unlock() + + inflight := value.(*inflightCachedEntry) + return inflight.value.(*routingKeyInfo) + } + + const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" + + // Running batch in default ks gocql_test + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + _, err = b1.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with default ks + routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt) + require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) + + // Running batch in gocql_test_routing_key_cache ks + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_routing_key_cache") + b2.Query(insertQuery, 2) + _, err = b2.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks + routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt) + require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace) + + const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?" + + // Running query in default ks gocql_test + q1 := session.Query(selectStmt, 1) + _, err = q1.GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + + // Running query in gocql_test_routing_key_cache ks + q2 := session.Query(selectStmt, 1) + _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() +} diff --git a/conn.go b/conn.go index 7c2728ba0..7a50a732b 100644 --- a/conn.go +++ b/conn.go @@ -1410,8 +1410,8 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) +func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), @@ -1486,10 +1486,6 @@ func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tra } } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { - return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace) -} - func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name @@ -1531,6 +1527,13 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.nowInSeconds = qry.nowInSecondsValue } + // If a keyspace for the qry is overriden, + // then we should use it to create stmt cache key + usedKeyspace := c.currentKeyspace + if qry.keyspace != "" { + usedKeyspace = qry.keyspace + } + var ( frame frameBuilder info *preparedStatment @@ -1539,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace) + info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1616,7 +1619,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // If a RESULT/Rows message reports // changed resultset metadata with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ @@ -1685,7 +1688,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: @@ -1767,6 +1770,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { req.nowInSeconds = batch.nowInSeconds } + usedKeyspace := c.currentKeyspace + if batch.keyspace != "" { + usedKeyspace = batch.keyspace + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { @@ -1774,7 +1782,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace) + info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1836,7 +1844,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { - key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) + key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) diff --git a/session.go b/session.go index 2175c28e2..5841efc49 100644 --- a/session.go +++ b/session.go @@ -591,11 +591,20 @@ func (s *Session) getConn() *Conn { return nil } -// returns routing key indexes and type info -func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) { +// Returns routing key indexes and type info. +// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace +func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) { + if keyspace == "" { + keyspace = s.cfg.Keyspace + } + + routingKeyInfoCacheKey := keyspace + stmt + s.routingKeyInfoCache.mu.Lock() - entry, cached := s.routingKeyInfoCache.lru.Get(stmt) + // Using here keyspace + stmt as a cache key because + // the query keyspace could be overridden via SetKeyspace + entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey) if cached { // done accessing the cache s.routingKeyInfoCache.mu.Unlock() @@ -619,7 +628,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI inflight := new(inflightCachedEntry) inflight.wg.Add(1) defer inflight.wg.Done() - s.routingKeyInfoCache.lru.Add(stmt, inflight) + s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight) s.routingKeyInfoCache.mu.Unlock() var ( @@ -635,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil) + info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace) if inflight.err != nil { // don't cache this error s.routingKeyInfoCache.Remove(stmt) @@ -651,7 +660,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } table := info.request.table - keyspace := info.request.keyspace + keyspace = info.request.keyspace if len(info.request.pkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns @@ -1177,7 +1186,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) { } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) if err != nil { return nil, err } @@ -2009,7 +2018,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) { return nil, nil } // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) if err != nil { return nil, err }