diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d2044a0c0..7b5f08a9e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,7 @@ jobs: go: [ '1.22', '1.23' ] cassandra_version: [ '4.0.13', '4.1.6' ] auth: [ "false" ] - compressor: [ "snappy" ] + compressor: [ "lz4" ] tags: [ "cassandra", "integration", "ccm" ] steps: - uses: actions/checkout@v2 @@ -101,7 +101,7 @@ jobs: ccm status ccm node1 nodetool status - args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." + args="-gocql.timeout=60s -runssl -proto=5 -rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV @@ -127,7 +127,7 @@ jobs: matrix: go: [ '1.22', '1.23' ] cassandra_version: [ '4.0.13' ] - compressor: [ "snappy" ] + compressor: [ "lz4" ] tags: [ "integration" ] steps: @@ -190,7 +190,7 @@ jobs: ccm status ccm node1 nodetool status - args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." + args="-gocql.timeout=60s -runssl -proto=5 -rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." 
echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV diff --git a/batch_test.go b/batch_test.go index 25f8c8364..7074628e6 100644 --- a/batch_test.go +++ b/batch_test.go @@ -28,6 +28,7 @@ package gocql import ( + "github.com/stretchr/testify/require" "testing" "time" ) @@ -84,3 +85,84 @@ func TestBatch_WithTimestamp(t *testing.T) { t.Errorf("got ts %d, expected %d", storedTs, micros) } } + +func TestBatch_WithNowInSeconds(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("Batch now in seconds are only available on protocol >= 5") + } + + if err := createTable(session, `CREATE TABLE IF NOT EXISTS batch_now_in_seconds (id int primary key, val text)`); err != nil { + t.Fatal(err) + } + + b := session.NewBatch(LoggedBatch) + b.WithNowInSeconds(0) + b.Query("INSERT INTO batch_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val") + if err := session.ExecuteBatch(b); err != nil { + t.Fatal(err) + } + + var remainingTTL int + err := session.Query(`SELECT TTL(val) FROM batch_now_in_seconds WHERE id = ?`, 1). + WithNowInSeconds(10). 
+ Scan(&remainingTTL) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, remainingTTL, 10) +} + +func TestBatch_SetKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("keyspace for BATCH message is not supported in protocol < 5") + } + + const keyspaceStmt = ` + CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test + WITH replication = { + 'class': 'SimpleStrategy', + 'replication_factor': '1' + }; +` + + err := session.Query(keyspaceStmt).Exec() + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + ids := []int{1, 2} + texts := []string{"val1", "val2"} + + b := session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test") + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0]) + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1]) + err = session.ExecuteBatch(b) + if err != nil { + t.Fatal(err) + } + + var ( + id int + text string + ) + + iter := session.Query("SELECT * FROM gocql_keyspace_override_test.batch_keyspace").Iter() + defer iter.Close() + + for i := 0; iter.Scan(&id, &text); i++ { + require.Equal(t, id, ids[i]) + require.Equal(t, text, texts[i]) + } +} diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..133e26167 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,6 +32,7 @@ import ( "context" "errors" "fmt" + "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -1482,7 +1483,7 @@ func TestQueryInfo(t *testing.T) { defer session.Close() conn := getRandomConn(t, session) - info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil) + info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id 
FROM system.local WHERE key = ?", nil, conn.currentKeyspace) if err != nil { t.Fatalf("Failed to execute query for preparing statement: %v", err) @@ -2601,7 +2602,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2625,7 +2626,7 @@ func TestRoutingKey(t *testing.T) { } // verify the cache is working - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -2659,7 +2660,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?") + routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? 
AND first_id=?", "") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } @@ -3288,3 +3289,460 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_WithNowInSeconds(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("Query now in seconds are only available on protocol >= 5") + } + + if err := createTable(session, `CREATE TABLE IF NOT EXISTS query_now_in_seconds (id int primary key, val text)`); err != nil { + t.Fatal(err) + } + + err := session.Query("INSERT INTO query_now_in_seconds (id, val) VALUES (?, ?) USING TTL 20", 1, "val"). + WithNowInSeconds(int(0)). + Exec() + if err != nil { + t.Fatal(err) + } + + var remainingTTL int + err = session.Query(`SELECT TTL(val) FROM query_now_in_seconds WHERE id = ?`, 1). + WithNowInSeconds(10). + Scan(&remainingTTL) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, remainingTTL, 10) +} + +func TestQuery_SetKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("keyspace for QUERY message is not supported in protocol < 5") + } + + const keyspaceStmt = ` + CREATE KEYSPACE IF NOT EXISTS gocql_query_keyspace_override_test + WITH replication = { + 'class': 'SimpleStrategy', + 'replication_factor': '1' + }; +` + + err := session.Query(keyspaceStmt).Exec() + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_query_keyspace_override_test.query_keyspace(id int, value text, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + expectedID := 1 + expectedText := "text" + + // Testing PREPARE message + err = session.Query("INSERT INTO gocql_query_keyspace_override_test.query_keyspace (id, value) VALUES (?, ?)", expectedID, expectedText).Exec() + if err != nil { + t.Fatal(err) + } + + var ( + id int + text string + ) + + q := session.Query("SELECT * FROM 
gocql_query_keyspace_override_test.query_keyspace"). + SetKeyspace("gocql_query_keyspace_override_test") + err = q.Scan(&id, &text) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, expectedID, id) + require.Equal(t, expectedText, text) + + // Testing QUERY message + id = 0 + text = "" + + q = session.Query("SELECT * FROM gocql_query_keyspace_override_test.query_keyspace"). + SetKeyspace("gocql_query_keyspace_override_test") + q.skipPrepare = true + err = q.Scan(&id, &text) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, expectedID, id) + require.Equal(t, expectedText, text) +} + +func TestLargeSizeQuery(t *testing.T) { + // TestLargeSizeQuery runs a query bigger than the max allowed size of the payload of a frame, + // so it should be sent as 2 different frames where each contains a self-contained bit set to zero. + + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.large_size_query(id int, text_col text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + longString := strings.Repeat("a", 500_000) + + err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", longString).Exec() + if err != nil { + t.Fatal(err) + } + + var result string + err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, longString, result) +} + +func TestQueryCompressionNotWorthIt(t *testing.T) { + // TestQueryCompressionNotWorthIt runs a query that is not likely to be compressed efficiently + // (uncompressed payload size > compressed payload size). + // So, it should send a Compressed Frame where: + // 1. Compressed length is set to the length of the uncompressed payload; + // 2. Uncompressed length is set to zero; + // 3. Payload is the uncompressed payload. 
+ + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.compression_now_worth_it(id int, text_col text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + str := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()_+" + err := session.Query("INSERT INTO gocql_test.compression_now_worth_it (id, text_col) VALUES (?, ?)", "1", str).Exec() + if err != nil { + t.Fatal(err) + } + + var result string + err = session.Query("SELECT text_col FROM gocql_test.compression_now_worth_it").Scan(&result) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, str, result) +} + +func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { + // This test ensures that the whole Metadata_changed flow + // is handled properly. + // + // To trigger C* to return Metadata_changed we should do: + // 1. Create a table + // 2. Prepare stmt which uses the created table + // 3. Change the table schema in order to affect prepared stmt (e.g. add a column) + // 4. Execute prepared stmt. As a result C* should return RESULT/ROWS response with + // Metadata_changed flag, new metadata id and updated metadata resultset. 
+ // + // The driver should handle this by updating its prepared statement inside the cache + // when it receives RESULT/ROWS with Metadata_changed flag + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("Metadata_changed mechanism is only available in proto > 4") + } + + if err := createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.metadata_changed(id int, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + type record struct { + id int + newCol int + } + + firstRecord := record{ + id: 1, + } + err := session.Query("INSERT INTO gocql_test.metadata_changed (id) VALUES (?)", firstRecord.id).Exec() + if err != nil { + t.Fatal(err) + } + + // We have to specify conn for all queries to ensure that + // all queries are running on the same node + conn := session.getConn() + + const selectStmt = "SELECT * FROM gocql_test.metadata_changed" + queryBeforeTableAltering := session.Query(selectStmt) + queryBeforeTableAltering.conn = conn + row := make(map[string]interface{}) + err = queryBeforeTableAltering.MapScan(row) + if err != nil { + t.Fatal(err) + } + + require.Len(t, row, 1, "Expected to retrieve a single column") + require.Equal(t, 1, row["id"]) + + stmtCacheKey := session.stmtsLRU.keyFor(conn.host.HostID(), conn.currentKeyspace, queryBeforeTableAltering.stmt) + inflight, _ := session.stmtsLRU.get(stmtCacheKey) + preparedStatementBeforeTableAltering := inflight.preparedStatment + + // Changing table schema in order to cause C* to return RESULT/ROWS Metadata_changed + alteringTableQuery := session.Query("ALTER TABLE gocql_test.metadata_changed ADD new_col int") + alteringTableQuery.conn = conn + err = alteringTableQuery.Exec() + if err != nil { + t.Fatal(err) + } + + secondRecord := record{ + id: 2, + newCol: 10, + } + err = session.Query("INSERT INTO gocql_test.metadata_changed (id, new_col) VALUES (?, ?)", secondRecord.id, secondRecord.newCol). 
+ Exec() + if err != nil { + t.Fatal(err) + } + + // Handles result from iter and ensures integrity of the result, + // closes iter and handles error + handleRows := func(iter *Iter) { + t.Helper() + + var scannedID int + var scannedNewCol *int // to perform null values + + // when the driver handling null values during unmarshalling + // it sets to dest type its zero value, which is (*int)(nil) for this case + var nilIntPtr *int + + // Scanning first row + if iter.Scan(&scannedID, &scannedNewCol) { + require.Equal(t, firstRecord.id, scannedID) + require.Equal(t, nilIntPtr, scannedNewCol) + } + + // Scanning second row + if iter.Scan(&scannedID, &scannedNewCol) { + require.Equal(t, secondRecord.id, scannedID) + require.Equal(t, &secondRecord.newCol, scannedNewCol) + } + + err := iter.Close() + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + t.Fatal("It is likely failed due deadlock") + } + t.Fatal(err) + } + } + + // Expecting C* will return RESULT/ROWS Metadata_changed + // and it will be properly handled + queryAfterTableAltering := session.Query(selectStmt) + queryAfterTableAltering.conn = conn + iter := queryAfterTableAltering.Iter() + handleRows(iter) + + // Ensuring if cache contains updated prepared statement + inflight, _ = session.stmtsLRU.get(stmtCacheKey) + preparedStatementAfterTableAltering := inflight.preparedStatment + require.NotEqual(t, preparedStatementBeforeTableAltering.resultMetadataID, preparedStatementAfterTableAltering.resultMetadataID) + require.NotEqual(t, preparedStatementBeforeTableAltering.response, preparedStatementAfterTableAltering.response) + + // FORCE SEND OLD RESULT METADATA ID (https://issues.apache.org/jira/browse/CASSANDRA-20028) + closedCh := make(chan struct{}) + close(closedCh) + session.stmtsLRU.add(stmtCacheKey, &inflightPrepare{ + done: closedCh, + err: nil, + preparedStatment: preparedStatementBeforeTableAltering, + }) + + // Running query with timeout to ensure there is no deadlocks. 
+ // However, it doesn't 100% proves that there is a deadlock... + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + queryAfterTableAltering2 := session.Query(selectStmt).WithContext(ctx) + queryAfterTableAltering2.conn = conn + iter = queryAfterTableAltering2.Iter() + handleRows(iter) + err = iter.Close() + + inflight, _ = session.stmtsLRU.get(stmtCacheKey) + preparedStatementAfterTableAltering2 := inflight.preparedStatment + require.NotEqual(t, preparedStatementBeforeTableAltering.resultMetadataID, preparedStatementAfterTableAltering2.resultMetadataID) + require.NotEqual(t, preparedStatementBeforeTableAltering.response, preparedStatementAfterTableAltering2.response) + + require.Equal(t, preparedStatementAfterTableAltering.resultMetadataID, preparedStatementAfterTableAltering2.resultMetadataID) + require.NotEqual(t, preparedStatementAfterTableAltering.response, preparedStatementAfterTableAltering2.response) // METADATA_CHANGED flag + require.True(t, preparedStatementAfterTableAltering2.response.flags&flagMetaDataChanged != 0) + + // Executing prepared stmt and expecting that C* won't return + // Metadata_changed because the table is not being changed. 
+ queryAfterTableAltering3 := session.Query(selectStmt).WithContext(ctx) + queryAfterTableAltering3.conn = conn + iter = queryAfterTableAltering3.Iter() + handleRows(iter) + + // Ensuring metadata of prepared stmt is not changed + inflight, _ = session.stmtsLRU.get(stmtCacheKey) + preparedStatementAfterTableAltering3 := inflight.preparedStatment + require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID) + require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response) +} + +func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("This tests only runs on proto > 4 due SetKeyspace availability") + } + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)" + + // Inserting data via Batch to ensure that batches + // properly accounts for keyspace overriding + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + err = session.ExecuteBatch(b1) + require.NoError(t, err) + + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_stmt_cache") + b2.Query(insertQuery, 2) + err = session.ExecuteBatch(b2) + require.NoError(t, err) + + var scannedID int + + const selectStmt = "SELECT * FROM 
stmt_cache_uses_overridden_ks" + + // By default in our test suite session uses gocql_test ks + err = session.Query(selectStmt).Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 1, scannedID) + + scannedID = 0 + err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID) + require.NoError(t, err) + require.Equal(t, 2, scannedID) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec() +} + +func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("This tests only runs on proto > 4 due SetKeyspace availability") + } + + const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }` + + err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache")) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + getRoutingKeyInfo := func(key string) *routingKeyInfo { + t.Helper() + session.routingKeyInfoCache.mu.Lock() + value, _ := session.routingKeyInfoCache.lru.Get(key) + session.routingKeyInfoCache.mu.Unlock() + + inflight := value.(*inflightCachedEntry) + return inflight.value.(*routingKeyInfo) + } + + const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" + + // Running batch in default ks gocql_test + b1 := session.NewBatch(LoggedBatch) + b1.Query(insertQuery, 1) + _, err = b1.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with default ks + routingKeyInfo1 := 
getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt) + require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) + + // Running batch in gocql_test_routing_key_cache ks + b2 := session.NewBatch(LoggedBatch) + b2.SetKeyspace("gocql_test_routing_key_cache") + b2.Query(insertQuery, 2) + _, err = b2.GetRoutingKey() + require.NoError(t, err) + + // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks + routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt) + require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace) + + const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?" + + // Running query in default ks gocql_test + q1 := session.Query(selectStmt, 1) + _, err = q1.GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + + // Running query in gocql_test_routing_key_cache ks + q2 := session.Query(selectStmt, 1) + _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + require.NoError(t, err) + require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + + session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() +} diff --git a/common_test.go b/common_test.go index a5edb03c6..cb230ede5 100644 --- a/common_test.go +++ b/common_test.go @@ -34,6 +34,9 @@ import ( "sync" "testing" "time" + + "github.com/gocql/gocql/lz4" + "github.com/gocql/gocql/snappy" ) var ( @@ -110,7 +113,9 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { switch *flagCompressTest { case "snappy": - cluster.Compressor = &SnappyCompressor{} + cluster.Compressor = &snappy.SnappyCompressor{} + case "lz4": + cluster.Compressor = lz4.LZ4Compressor{} case "": default: panic("invalid compressor: " + *flagCompressTest) diff --git a/compressor.go b/compressor.go index f3d451a9f..a4c305b7e 100644 --- a/compressor.go +++ b/compressor.go @@ -24,29 +24,24 @@ package gocql -import ( - 
"github.com/golang/snappy" -) - type Compressor interface { Name() string - Encode(data []byte) ([]byte, error) - Decode(data []byte) ([]byte, error) -} -// SnappyCompressor implements the Compressor interface and can be used to -// compress incoming and outgoing frames. The snappy compression algorithm -// aims for very high speeds and reasonable compression. -type SnappyCompressor struct{} + // AppendCompressedWithLength compresses src bytes, appends the length of the compressed bytes to dst + // and then appends the compressed bytes to dst. + // It returns a new byte slice that is the result of the append operation. + AppendCompressedWithLength(dst, src []byte) ([]byte, error) -func (s SnappyCompressor) Name() string { - return "snappy" -} + // AppendDecompressedWithLength reads the length of the decompressed bytes from src, + // decompressed bytes from src and appends the decompressed bytes to dst. + // It returns a new byte slice that is the result of the append operation. + AppendDecompressedWithLength(dst, src []byte) ([]byte, error) -func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { - return snappy.Encode(nil, data), nil -} + // AppendCompressed compresses src bytes and appends the compressed bytes to dst. + // It returns a new byte slice that is the result of the append operation. + AppendCompressed(dst, src []byte) ([]byte, error) -func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { - return snappy.Decode(nil, data) + // AppendDecompressed decompresses bytes from src and appends the decompressed bytes to dst. + // It returns a new byte slice that is the result of the append operation. 
+ AppendDecompressed(dst, src []byte, decompressedLength uint32) ([]byte, error) } diff --git a/conn.go b/conn.go index 3daca6250..1fd3ea3cf 100644 --- a/conn.go +++ b/conn.go @@ -26,6 +26,7 @@ package gocql import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -186,11 +187,9 @@ var TimeoutLimit int64 = 0 // queries, but users are usually advised to use a more reliable, higher // level API. type Conn struct { - conn net.Conn - r *bufio.Reader - w contextWriter + r ConnReader + w contextWriter - timeout time.Duration writeTimeout time.Duration cfg *ConnConfig frameObserver FrameHeaderObserver @@ -268,8 +267,10 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * ctx, cancel := context.WithCancel(ctx) c := &Conn{ - conn: dialedHost.Conn, - r: bufio.NewReader(dialedHost.Conn), + r: &connReader{ + conn: dialedHost.Conn, + r: bufio.NewReader(dialedHost.Conn), + }, cfg: cfg, calls: make(map[int]*callReq), version: uint8(cfg.ProtoVersion), @@ -319,16 +320,16 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error { conn: c, } - c.timeout = c.cfg.ConnectTimeout + c.r.SetTimeout(c.cfg.ConnectTimeout) if err := startup.setupConn(ctx); err != nil { return err } - c.timeout = c.cfg.Timeout + c.r.SetTimeout(c.cfg.Timeout) // dont coalesce startup frames if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce { - c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done()) + c.w = newWriteCoalescer(dialedHost.Conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done()) } go c.serve(ctx) @@ -341,29 +342,6 @@ func (c *Conn) Write(p []byte) (n int, err error) { return c.w.writeContext(context.Background(), p) } -func (c *Conn) Read(p []byte) (n int, err error) { - const maxAttempts = 5 - - for i := 0; i < maxAttempts; i++ { - var nn int - if c.timeout > 0 { - c.conn.SetReadDeadline(time.Now().Add(c.timeout)) - } - - nn, err = 
io.ReadFull(c.r, p[n:]) - n += nn - if err == nil { - break - } - - if verr, ok := err.(net.Error); !ok || !verr.Temporary() { - break - } - } - - return -} - type startupCoordinator struct { conn *Conn frameTicker chan struct{} @@ -371,17 +349,26 @@ type startupCoordinator struct { func (s *startupCoordinator) setupConn(ctx context.Context) error { var cancel context.CancelFunc - if s.conn.timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, s.conn.timeout) + if s.conn.r.GetTimeout() > 0 { + ctx, cancel = context.WithTimeout(ctx, s.conn.r.GetTimeout()) } else { ctx, cancel = context.WithCancel(ctx) } defer cancel() + // Only for proto v5+. + // Indicates if STARTUP has been completed. + // github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec + // 2.3.1 Initial Handshake + // In order to support both v5 and earlier formats, the v5 framing format is not + // applied to message exchanges before an initial handshake is completed. + startupCompleted := &atomic.Bool{} + startupCompleted.Store(false) + startupErr := make(chan error) go func() { for range s.frameTicker { - err := s.conn.recv(ctx) + err := s.conn.recv(ctx, startupCompleted.Load()) if err != nil { select { case startupErr <- err: @@ -395,7 +382,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error { go func() { defer close(s.frameTicker) - err := s.options(ctx) + err := s.options(ctx, startupCompleted) select { case startupErr <- err: case <-ctx.Done(): @@ -414,14 +401,14 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error { return nil } -func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) { +func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder, startupCompleted *atomic.Bool) (frame, error) { select { case s.frameTicker <- struct{}{}: case <-ctx.Done(): return nil, ctx.Err() } - framer, err := s.conn.exec(ctx, frame, nil) + framer, err := s.conn.execInternal(ctx, frame, nil, 
startupCompleted.Load()) if err != nil { return nil, err } @@ -429,8 +416,8 @@ func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (fra return framer.parseFrame() } -func (s *startupCoordinator) options(ctx context.Context) error { - frame, err := s.write(ctx, &writeOptionsFrame{}) +func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atomic.Bool) error { + frame, err := s.write(ctx, &writeOptionsFrame{}, startupCompleted) if err != nil { return err } @@ -440,10 +427,10 @@ func (s *startupCoordinator) options(ctx context.Context) error { return NewErrProtocol("Unknown type of response to startup frame: %T", frame) } - return s.startup(ctx, supported.supported) + return s.startup(ctx, supported.supported, startupCompleted) } -func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { +func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string, startupCompleted *atomic.Bool) error { m := map[string]string{ "CQL_VERSION": s.conn.cfg.CQLVersion, "DRIVER_NAME": driverName, @@ -465,7 +452,7 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ } } - frame, err := s.write(ctx, &writeStartupFrame{opts: m}) + frame, err := s.write(ctx, &writeStartupFrame{opts: m}, startupCompleted) if err != nil { return err } @@ -474,15 +461,19 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ case error: return v case *readyFrame: + // Startup is successfully completed, so we could use Native Protocol 5 + startupCompleted.Store(true) return nil case *authenticateFrame: - return s.authenticateHandshake(ctx, v) + // Startup is successfully completed, so we could use Native Protocol 5 + startupCompleted.Store(true) + return s.authenticateHandshake(ctx, v, startupCompleted) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (s *startupCoordinator) authenticateHandshake(ctx 
context.Context, authFrame *authenticateFrame) error { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, startupCompleted *atomic.Bool) error { if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } @@ -494,7 +485,7 @@ func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFram req := &writeAuthResponseFrame{data: resp} for { - frame, err := s.write(ctx, req) + frame, err := s.write(ctx, req, startupCompleted) if err != nil { return err } @@ -563,7 +554,7 @@ func (c *Conn) closeWithError(err error) { // if error was nil then unblock the quit channel c.cancel() - cerr := c.close() + cerr := c.r.Close() if err != nil { c.errorHandler.HandleError(c, err, true) @@ -573,10 +564,6 @@ func (c *Conn) closeWithError(err error) { } } -func (c *Conn) close() error { - return c.conn.Close() -} - func (c *Conn) Close() { c.closeWithError(nil) } @@ -587,14 +574,14 @@ func (c *Conn) Close() { func (c *Conn) serve(ctx context.Context) { var err error for err == nil { - err = c.recv(ctx) + err = c.recv(ctx, true) } c.closeWithError(err) } -func (c *Conn) discardFrame(head frameHeader) error { - _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) +func (c *Conn) discardFrame(r io.Reader, head frameHeader) error { + _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) if err != nil { return err } @@ -659,18 +646,28 @@ func (c *Conn) heartBeat(ctx context.Context) { } } -func (c *Conn) recv(ctx context.Context) error { +func (c *Conn) recv(ctx context.Context, startupCompleted bool) error { + // If startup is completed and native proto 5+ is set up then we should + // unwrap payload from compressed/uncompressed frame + if startupCompleted && c.version > protoVersion4 { + return c.recvSegment(ctx) + } + + return c.processFrame(ctx, c.r) +} + +func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { // not safe for concurrent reads // read a full 
header, ignore timeouts, as this is being ran in a loop // TODO: TCP level deadlines? or just query level deadlines? - if c.timeout > 0 { - c.conn.SetReadDeadline(time.Time{}) + if c.r.GetTimeout() > 0 { + c.r.SetReadDeadline(time.Time{}) } headStartTime := time.Now() // were just reading headers over and over and copy bodies - head, err := readHeader(c.r, c.headerBuf[:]) + head, err := readHeader(r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err @@ -694,7 +691,7 @@ func (c *Conn) recv(ctx context.Context) error { } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + if err := framer.readFrame(r, &head); err != nil { return err } go c.session.handleEvent(framer) @@ -703,7 +700,7 @@ func (c *Conn) recv(ctx context.Context) error { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + if err := framer.readFrame(r, &head); err != nil { return err } @@ -727,14 +724,14 @@ func (c *Conn) recv(ctx context.Context) error { c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) - return c.discardFrame(head) + return c.discardFrame(r, head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } framer := newFramer(c.compressor, c.version) - err = framer.readFrame(c, &head) + err = framer.readFrame(r, &head) if err != nil { // only net errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. 
@@ -777,6 +774,172 @@ func (c *Conn) handleTimeout() { } } +func (c *Conn) recvSegment(ctx context.Context) error { + var ( + frame []byte + isSelfContained bool + err error + ) + + // Read frame based on compression + if c.compressor != nil { + frame, isSelfContained, err = readCompressedSegment(c.r, c.compressor) + } else { + frame, isSelfContained, err = readUncompressedSegment(c.r) + } + if err != nil { + return err + } + + if isSelfContained { + return c.processAllFramesInSegment(ctx, bytes.NewReader(frame)) + } + + head, err := readHeader(bytes.NewReader(frame), c.headerBuf[:]) + if err != nil { + return err + } + + const frameHeaderLength = 9 + buf := bytes.NewBuffer(make([]byte, 0, head.length+frameHeaderLength)) + buf.Write(frame) + + // Computing how many bytes of message left to read + bytesToRead := head.length - len(frame) + frameHeaderLength + + err = c.recvPartialFrames(buf, bytesToRead) + if err != nil { + return err + } + + return c.processFrame(ctx, buf) +} + +// recvPartialFrames reads proto v5 segments from Conn.r and writes decoded partial frames to dst. +// It reads data until the bytesToRead is reached. +// If Conn.compressor is not nil, it processes Compressed Format segments. 
+func (c *Conn) recvPartialFrames(dst *bytes.Buffer, bytesToRead int) error { + var ( + read int + frame []byte + isSelfContained bool + err error + ) + + for read != bytesToRead { + // Read frame based on compression + if c.compressor != nil { + frame, isSelfContained, err = readCompressedSegment(c.r, c.compressor) + } else { + frame, isSelfContained, err = readUncompressedSegment(c.r) + } + if err != nil { + return fmt.Errorf("gocql: failed to read non self-contained frame: %w", err) + } + + if isSelfContained { + return fmt.Errorf("gocql: received self-contained segment, but expected not") + } + + if totalLength := dst.Len() + len(frame); totalLength > dst.Cap() { + return fmt.Errorf("gocql: expected partial frame of length %d, got %d", dst.Cap(), totalLength) + } + + // Write the frame to the destination writer + n, _ := dst.Write(frame) + read += n + } + + return nil +} + +func (c *Conn) processAllFramesInSegment(ctx context.Context, r *bytes.Reader) error { + var err error + for r.Len() > 0 && err == nil { + err = c.processFrame(ctx, r) + } + + return err +} + +// ConnReader is like net.Conn but also allows to set timeout duration. +type ConnReader interface { + net.Conn + + // SetTimeout sets timeout duration for reading data form conn + SetTimeout(timeout time.Duration) + + // GetTimeout returns timeout duration + GetTimeout() time.Duration +} + +// connReader implements ConnReader. +// It retries to read data up to 5 times or returns error. 
+type connReader struct { + conn net.Conn + r *bufio.Reader + timeout time.Duration +} + +func (c *connReader) Read(p []byte) (n int, err error) { + const maxAttempts = 5 + + for i := 0; i < maxAttempts; i++ { + var nn int + if c.timeout > 0 { + c.conn.SetReadDeadline(time.Now().Add(c.timeout)) + } + + nn, err = io.ReadFull(c.r, p[n:]) + n += nn + if err == nil { + break + } + + if verr, ok := err.(net.Error); !ok || !verr.Temporary() { + break + } + } + + return +} + +func (c *connReader) Write(b []byte) (n int, err error) { + return c.conn.Write(b) +} + +func (c *connReader) Close() error { + return c.conn.Close() +} + +func (c *connReader) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *connReader) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *connReader) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *connReader) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *connReader) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *connReader) SetTimeout(timeout time.Duration) { + c.timeout = timeout +} + +func (c *connReader) GetTimeout() time.Duration { + return c.timeout +} + type callReq struct { // resp will receive the frame that was sent as a response to this stream. 
resp chan callResp @@ -1027,6 +1190,10 @@ func (c *Conn) addCall(call *callReq) error { } func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) { + return c.execInternal(ctx, req, tracer, true) +} + +func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer, startupCompleted bool) (*framer, error) { if ctxErr := ctx.Err(); ctxErr != nil { return nil, ctxErr } @@ -1086,7 +1253,14 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram return nil, err } - n, err := c.w.writeContext(ctx, framer.buf) + var n int + + if c.version > protoVersion4 && startupCompleted { + err = framer.prepareModernLayout() + } + if err == nil { + n, err = c.w.writeContext(ctx, framer.buf) + } if err != nil { // closeWithError will block waiting for this stream to either receive a response // or for us to timeout, close the timeout chan here. Im not entirely sure @@ -1115,7 +1289,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram } var timeoutCh <-chan time.Time - if c.timeout > 0 { + if timeout := c.r.GetTimeout(); timeout > 0 { if call.timer == nil { call.timer = time.NewTimer(0) <-call.timer.C @@ -1128,7 +1302,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram } } - call.timer.Reset(c.timeout) + call.timer.Reset(timeout) timeoutCh = call.timer.C } @@ -1223,9 +1397,10 @@ type StreamObserverContext interface { } type preparedStatment struct { - id []byte - request preparedMetadata - response resultMetadata + id []byte + resultMetadataID []byte + request preparedMetadata + response resultMetadata } type inflightPrepare struct { @@ -1235,8 +1410,8 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) +func (c *Conn) 
prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), @@ -1253,7 +1428,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) statement: stmt, } if c.version > protoVersion4 { - prep.keyspace = c.currentKeyspace + prep.keyspace = keyspace } // we won the race to do the load, if our context is canceled we shouldnt @@ -1284,7 +1459,8 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: copyBytes(x.preparedID), + resultMetadataID: copyBytes(x.resultMetadataID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. request: x.reqMeta, @@ -1347,7 +1523,15 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.pageSize = qry.pageSize } if c.version > protoVersion4 { - params.keyspace = c.currentKeyspace + params.keyspace = qry.keyspace + params.nowInSeconds = qry.nowInSecondsValue + } + + // If a keyspace for the qry is overriden, + // then we should use it to create stmt cache key + usedKeyspace := c.currentKeyspace + if qry.keyspace != "" { + usedKeyspace = qry.keyspace } var ( @@ -1358,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. 
var err error - info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) + info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1394,14 +1578,18 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) frame = &writeExecuteFrame{ - preparedID: info.id, - params: params, - customPayload: qry.customPayload, + preparedID: info.id, + params: params, + customPayload: qry.customPayload, + resultMetadataID: info.resultMetadataID, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata qry.routingInfo.mu.Lock() qry.routingInfo.keyspace = info.request.keyspace + if info.request.keyspace == "" { + qry.routingInfo.keyspace = usedKeyspace + } qry.routingInfo.table = info.request.table qry.routingInfo.mu.Unlock() } else { @@ -1430,13 +1618,39 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { case *resultVoidFrame: return &Iter{framer: framer} case *resultRowsFrame: + if x.meta.newMetadataID != nil { + // If a RESULT/Rows message reports + // changed resultset metadata with the Metadata_changed flag, the reported new + // resultset metadata must be used in subsequent executions + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) + oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) + if ok { + newInflight := &inflightPrepare{ + done: make(chan struct{}), + preparedStatment: &preparedStatment{ + id: oldInflight.preparedStatment.id, + resultMetadataID: x.meta.newMetadataID, + request: oldInflight.preparedStatment.request, + response: x.meta, + }, + } + // The driver should close this done to avoid deadlocks of + // other subsequent requests + close(newInflight.done) + c.session.stmtsLRU.add(stmtCacheKey, newInflight) + // Updating info to ensure the code is looking at the updated + // version of the prepared statement + info = 
newInflight.preparedStatment + } + } + iter := &Iter{ meta: x.meta, framer: framer, numRows: x.numRows, } - if params.skipMeta { + if x.meta.noMetaData() { if info != nil { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) @@ -1477,7 +1691,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: @@ -1554,6 +1768,16 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { customPayload: batch.CustomPayload, } + if c.version > protoVersion4 { + req.keyspace = batch.keyspace + req.nowInSeconds = batch.nowInSeconds + } + + usedKeyspace := c.currentKeyspace + if batch.keyspace != "" { + usedKeyspace = batch.keyspace + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { @@ -1561,7 +1785,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) + info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) if err != nil { return &Iter{err: err} } @@ -1623,7 +1847,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { - key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) + key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) diff --git a/conn_test.go b/conn_test.go index cab4c2f8f..f5735510f 
100644 --- a/conn_test.go +++ b/conn_test.go @@ -46,6 +46,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/gocql/gocql/internal/streams" ) @@ -707,12 +709,14 @@ func TestStream0(t *testing.T) { } conn := &Conn{ - r: bufio.NewReader(&buf), + r: &connReader{ + r: bufio.NewReader(&buf), + }, streams: streams.New(protoVersion4), logger: &defaultLogger{}, } - err := conn.recv(context.Background()) + err := conn.recv(context.Background(), false) if err == nil { t.Fatal("expected to get an error on stream 0") } else if !strings.HasPrefix(err.Error(), expErr) { @@ -1300,3 +1304,98 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { return framer, nil } + +func TestConnProcessAllFramesInSingleSegment(t *testing.T) { + server, client, err := tcpConnPair() + require.NoError(t, err) + + c := &Conn{ + r: &connReader{ + conn: server, + r: bufio.NewReader(server), + }, + calls: make(map[int]*callReq), + version: protoVersion5, + addr: server.RemoteAddr().String(), + streams: streams.New(protoVersion5), + isSchemaV2: true, + w: &deadlineContextWriter{ + w: server, + timeout: time.Second * 10, + semaphore: make(chan struct{}, 1), + quit: make(chan struct{}), + }, + logger: Logger, + writeTimeout: time.Second * 10, + } + + call1 := &callReq{ + timeout: make(chan struct{}), + streamID: 1, + resp: make(chan callResp), + } + + call2 := &callReq{ + timeout: make(chan struct{}), + streamID: 2, + resp: make(chan callResp), + } + + c.calls[1] = call1 + c.calls[2] = call2 + + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ + consistency: Quorum, + keyspace: "gocql_test", + }, + } + + framer1 := newFramer(nil, protoVersion5) + err = req.buildFrame(framer1, 1) + require.NoError(t, err) + + framer2 := newFramer(nil, protoVersion5) + err = req.buildFrame(framer2, 2) + require.NoError(t, err) + + go func() { + var buf []byte + buf = append(buf, framer1.buf...) + buf = append(buf, framer2.buf...) 
+ + uncompressedSegment, err := newUncompressedSegment(buf, true) + require.NoError(t, err) + + _, err = client.Write(uncompressedSegment) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- c.recvSegment(ctx) + }() + + go func() { + resp1 := <-call1.resp + close(call1.timeout) + // Skipping here the header of the frame because resp.framer contains already parsed header + // and resp.framer.buf contains frame body + require.Equal(t, framer1.buf[9:], resp1.framer.buf) + + resp2 := <-call2.resp + close(call2.timeout) + require.Equal(t, framer2.buf[9:], resp2.framer.buf) + }() + + select { + case <-ctx.Done(): + t.Fatal("Timed out waiting for frames") + case err := <-errCh: + require.NoError(t, err) + } +} diff --git a/control.go b/control.go index b30b44ea3..3009d0ff9 100644 --- a/control.go +++ b/control.go @@ -216,7 +216,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) connCfg := *c.session.connCfg - connCfg.ProtoVersion = 4 // TODO: define maxProtocol + connCfg.ProtoVersion = 5 // TODO: define maxProtocol handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a @@ -294,7 +294,7 @@ type connHost struct { func (c *controlConn) setupConn(conn *Conn) error { // we need up-to-date host info for the filterHost call below iter := conn.querySystemLocal(context.TODO()) - host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.conn.RemoteAddr().(*net.TCPAddr).Port) + host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.r.RemoteAddr().(*net.TCPAddr).Port) if err != nil { return err } diff --git a/crc.go b/crc.go new file mode 100644 index 000000000..64474ada1 --- /dev/null +++ b/crc.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + 
* or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "hash/crc32" +) + +var ( + // Initial CRC32 bytes: 0xFA, 0x2D, 0x55, 0xCA + initialCRC32Bytes = []byte{0xfa, 0x2d, 0x55, 0xca} +) + +// Crc32 calculates the CRC32 checksum of the given byte slice. +func Crc32(b []byte) uint32 { + crc := crc32.NewIEEE() + crc.Write(initialCRC32Bytes) // Include initial CRC32 bytes + crc.Write(b) + return crc.Sum32() +} + +const ( + crc24Init = 0x875060 // Initial value for CRC24 calculation + crc24Poly = 0x1974F0B // Polynomial for CRC24 calculation +) + +// Crc24 calculates the CRC24 checksum using the Koopman polynomial. +func Crc24(buf []byte) uint32 { + crc := crc24Init + for _, b := range buf { + crc ^= int(b) << 16 + + for i := 0; i < 8; i++ { + crc <<= 1 + if crc&0x1000000 != 0 { + crc ^= crc24Poly + } + } + } + + return uint32(crc) +} diff --git a/crc_test.go b/crc_test.go new file mode 100644 index 000000000..cf5e40a35 --- /dev/null +++ b/crc_test.go @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestChecksumIEEE(t *testing.T) { + tests := []struct { + name string + buf []byte + expected uint32 + }{ + // expected values are manually generated using crc32 impl in Cassandra + { + name: "empty buf", + buf: []byte{}, + expected: 1148681939, + }, + { + name: "buf filled with 0", + buf: []byte{0, 0, 0, 0, 0}, + expected: 1178391023, + }, + { + name: "buf filled with some data", + buf: []byte{1, 2, 3, 4, 5, 6}, + expected: 3536190002, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, Crc32(tt.buf)) + }) + } +} + +func TestKoopmanChecksum(t *testing.T) { + tests := []struct { + name string + buf []byte + expected uint32 + }{ + // expected values are manually generated using crc24 impl in Cassandra + { + name: "buf filled with 0 (len 3)", + buf: []byte{0, 0, 0}, + expected: 8251255, + }, + { + name: "buf filled with 0 (len 5)", + buf: []byte{0, 0, 0, 0, 0}, + expected: 11185162, + }, + { + name: "buf filled with some data (len 3)", + buf: []byte{64, -30 & 0xff, 1}, + expected: 5891942, + }, + { + name: "buf filled with some data (len 5)", + buf: []byte{64, -30 & 0xff, 1, 0, 0}, + expected: 8775784, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, Crc24(tt.buf)) + }) + } +} diff 
--git a/frame.go b/frame.go index d374ae574..f8e08ebc3 100644 --- a/frame.go +++ b/frame.go @@ -25,7 +25,9 @@ package gocql import ( + "bytes" "context" + "encoding/binary" "errors" "fmt" "io" @@ -70,6 +72,8 @@ const ( protoVersion5 = 0x05 maxFrameSize = 256 * 1024 * 1024 + + maxSegmentPayloadSize = 0x1FFFF ) type protoVersion byte @@ -168,16 +172,18 @@ const ( flagGlobalTableSpec int = 0x01 flagHasMorePages int = 0x02 flagNoMetaData int = 0x04 + flagMetaDataChanged int = 0x08 // query flags - flagValues byte = 0x01 - flagSkipMetaData byte = 0x02 - flagPageSize byte = 0x04 - flagWithPagingState byte = 0x08 - flagWithSerialConsistency byte = 0x10 - flagDefaultTimestamp byte = 0x20 - flagWithNameValues byte = 0x40 - flagWithKeyspace byte = 0x80 + flagValues uint32 = 0x01 + flagSkipMetaData uint32 = 0x02 + flagPageSize uint32 = 0x04 + flagWithPagingState uint32 = 0x08 + flagWithSerialConsistency uint32 = 0x10 + flagDefaultTimestamp uint32 = 0x20 + flagWithNameValues uint32 = 0x40 + flagWithKeyspace uint32 = 0x80 + flagWithNowInSeconds uint32 = 0x100 // prepare flags flagWithPreparedKeyspace uint32 = 0x01 @@ -524,12 +530,12 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error { return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) } - if head.flags&flagCompress == flagCompress { + if f.proto < protoVersion5 && head.flags&flagCompress == flagCompress { if f.compres == nil { return NewErrProtocol("no compressor available with compressed frame body") } - f.buf, err = f.compres.Decode(f.buf) + f.buf, err = f.compres.AppendDecompressedWithLength(nil, f.buf) if err != nil { return err } @@ -768,13 +774,13 @@ func (f *framer) finish() error { return ErrFrameTooBig } - if f.buf[1]&flagCompress == flagCompress { + if f.proto < protoVersion5 && f.buf[1]&flagCompress == flagCompress { if f.compres == nil { panic("compress flag set with no compressor") } // TODO: only compress frames which are big enough - compressed, err := 
f.compres.Encode(f.buf[f.headSize:]) + compressed, err := f.compres.AppendCompressedWithLength(nil, f.buf[f.headSize:]) if err != nil { return err } @@ -1017,14 +1023,20 @@ type resultMetadata struct { // it is at minimum len(columns) but may be larger, for instance when a column // is a UDT or tuple. actualColCount int + + newMetadataID []byte } func (r *resultMetadata) morePages() bool { return r.flags&flagHasMorePages == flagHasMorePages } +func (r *resultMetadata) noMetaData() bool { + return r.flags&flagNoMetaData == flagNoMetaData +} + func (r resultMetadata) String() string { - return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) + return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v new_metadata_id=% X]", r.flags, r.pagingState, r.columns, r.newMetadataID) } func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { @@ -1060,7 +1072,11 @@ func (f *framer) parseResultMetadata() resultMetadata { meta.pagingState = copyBytes(f.readBytes()) } - if meta.flags&flagNoMetaData == flagNoMetaData { + if f.proto > protoVersion4 && meta.flags&flagMetaDataChanged == flagMetaDataChanged { + meta.newMetadataID = copyBytes(f.readShortBytes()) + } + + if meta.noMetaData() { return meta } @@ -1164,18 +1180,24 @@ func (f *framer) parseResultSetKeyspace() frame { type resultPreparedFrame struct { frameHeader - preparedID []byte - reqMeta preparedMetadata - respMeta resultMetadata + preparedID []byte + resultMetadataID []byte + reqMeta preparedMetadata + respMeta resultMetadata } func (f *framer) parseResultPrepared() frame { frame := &resultPreparedFrame{ frameHeader: *f.header, preparedID: f.readShortBytes(), - reqMeta: f.parsePreparedMetadata(), } + if f.proto > protoVersion4 { + frame.resultMetadataID = copyBytes(f.readShortBytes()) + } + + frame.reqMeta = f.parsePreparedMetadata() + if f.proto < protoVersion2 { return frame } @@ -1457,12 +1479,13 @@ type 
queryParams struct { defaultTimestamp bool defaultTimestampValue int64 // v5+ - keyspace string + keyspace string + nowInSeconds *int } func (q queryParams) String() string { - return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", - q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) + return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s now_in_seconds=%v]", + q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace, q.nowInSeconds) } func (f *framer) writeQueryParams(opts *queryParams) { @@ -1472,7 +1495,9 @@ func (f *framer) writeQueryParams(opts *queryParams) { return } - var flags byte + var flags uint32 + names := false + if len(opts.values) > 0 { flags |= flagValues } @@ -1489,8 +1514,6 @@ func (f *framer) writeQueryParams(opts *queryParams) { flags |= flagWithSerialConsistency } - names := false - // protoV3 specific things if f.proto > protoVersion2 { if opts.defaultTimestamp { @@ -1504,17 +1527,23 @@ func (f *framer) writeQueryParams(opts *queryParams) { } if opts.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithKeyspace - } else { + if f.proto < protoVersion5 { panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) } + flags |= flagWithKeyspace + } + + if opts.nowInSeconds != nil { + if f.proto < protoVersion5 { + panic(fmt.Errorf("now_in_seconds can only be set with protocol 5 or higher")) + } + flags |= flagWithNowInSeconds } if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) + f.writeUint(flags) } else { - f.writeByte(flags) + f.writeByte(byte(flags)) } if n := len(opts.values); n > 0 { @@ -1558,6 +1587,10 @@ func (f *framer) writeQueryParams(opts *queryParams) { if opts.keyspace 
!= "" { f.writeString(opts.keyspace) } + + if opts.nowInSeconds != nil { + f.writeInt(int32(*opts.nowInSeconds)) + } } type writeQueryFrame struct { @@ -1604,6 +1637,9 @@ type writeExecuteFrame struct { // v4+ customPayload map[string][]byte + + // v5+ + resultMetadataID []byte } func (e *writeExecuteFrame) String() string { @@ -1611,16 +1647,21 @@ func (e *writeExecuteFrame) String() string { } func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) + return fr.writeExecuteFrame(streamID, e.preparedID, e.resultMetadataID, &e.params, &e.customPayload) } -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { +func (f *framer) writeExecuteFrame(streamID int, preparedID, resultMetadataID []byte, params *queryParams, customPayload *map[string][]byte) error { if len(*customPayload) > 0 { f.payload() } f.writeHeader(f.flags, opExecute, streamID) f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) + + if f.proto > protoVersion4 { + f.writeShortBytes(resultMetadataID) + } + if f.proto > protoVersion1 { f.writeQueryParams(params) } else { @@ -1659,6 +1700,10 @@ type writeBatchFrame struct { //v4+ customPayload map[string][]byte + + //v5+ + keyspace string + nowInSeconds *int } func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { @@ -1676,7 +1721,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload n := len(w.statements) f.writeShort(uint16(n)) - var flags byte + var flags uint32 for i := 0; i < n; i++ { b := &w.statements[i] @@ -1717,26 +1762,48 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload if w.defaultTimestamp { flags |= flagDefaultTimestamp } + } - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) - } else { - f.writeByte(flags) + if w.keyspace != "" { + if f.proto < protoVersion5 { + 
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) } + flags |= flagWithKeyspace + } - if w.serialConsistency > 0 { - f.writeConsistency(Consistency(w.serialConsistency)) + if w.nowInSeconds != nil { + if f.proto < protoVersion5 { + panic(fmt.Errorf("now_in_seconds can only be set with protocol 5 or higher")) } + flags |= flagWithNowInSeconds + } - if w.defaultTimestamp { - var ts int64 - if w.defaultTimestampValue != 0 { - ts = w.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) + if f.proto > protoVersion4 { + f.writeUint(flags) + } else { + f.writeByte(byte(flags)) + } + + if w.serialConsistency > 0 { + f.writeConsistency(Consistency(w.serialConsistency)) + } + + if w.defaultTimestamp { + var ts int64 + if w.defaultTimestampValue != 0 { + ts = w.defaultTimestampValue + } else { + ts = time.Now().UnixNano() / 1000 } + f.writeLong(ts) + } + + if w.keyspace != "" { + f.writeString(w.keyspace) + } + + if w.nowInSeconds != nil { + f.writeInt(int32(*w.nowInSeconds)) } return f.finish() @@ -2070,3 +2137,262 @@ func (f *framer) writeBytesMap(m map[string][]byte) { f.writeBytes(v) } } + +func (f *framer) prepareModernLayout() error { + // Ensure protocol version is V5 or higher + if f.proto < protoVersion5 { + panic("Modern layout is not supported with version V4 or less") + } + + selfContained := true + + var ( + adjustedBuf []byte + tempBuf []byte + err error + ) + + // Process the buffer in chunks if it exceeds the max payload size + for len(f.buf) > maxSegmentPayloadSize { + if f.compres != nil { + tempBuf, err = newCompressedSegment(f.buf[:maxSegmentPayloadSize], false, f.compres) + } else { + tempBuf, err = newUncompressedSegment(f.buf[:maxSegmentPayloadSize], false) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) 
+ f.buf = f.buf[maxSegmentPayloadSize:] + selfContained = false + } + + // Process the remaining buffer + if f.compres != nil { + tempBuf, err = newCompressedSegment(f.buf, selfContained, f.compres) + } else { + tempBuf, err = newUncompressedSegment(f.buf, selfContained) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) + f.buf = adjustedBuf + + return nil +} + +const ( + crc24Size = 3 + crc32Size = 4 +) + +func readUncompressedSegment(r io.Reader) ([]byte, bool, error) { + const ( + headerSize = 3 + ) + + header := [headerSize + crc24Size]byte{} + + // Read the frame header + if _, err := io.ReadFull(r, header[:]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame, err: %w", err) + } + + // Compute and verify the header CRC24 + computedHeaderCRC24 := Crc24(header[:headerSize]) + readHeaderCRC24 := uint32(header[3]) | uint32(header[4])<<8 | uint32(header[5])<<16 + if computedHeaderCRC24 != readHeaderCRC24 { + return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, computed: %d, got: %d", computedHeaderCRC24, readHeaderCRC24) + } + + // Extract the payload length and self-contained flag + headerInt := uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16 + payloadLen := int(headerInt & maxSegmentPayloadSize) + isSelfContained := (headerInt & (1 << 17)) != 0 + + // Read the payload + payload := make([]byte, payloadLen) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame payload, err: %w", err) + } + + // Read and verify the payload CRC32 + if _, err := io.ReadFull(r, header[:crc32Size]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err) + } + + computedPayloadCRC32 := Crc32(payload) + readPayloadCRC32 := binary.LittleEndian.Uint32(header[:crc32Size]) + if computedPayloadCRC32 != readPayloadCRC32 { + return nil, false, fmt.Errorf("gocql: payload 
crc32 mismatch, computed: %d, got: %d", computedPayloadCRC32, readPayloadCRC32) + } + + return payload, isSelfContained, nil +} + +func newUncompressedSegment(payload []byte, isSelfContained bool) ([]byte, error) { + const ( + headerSize = 6 + selfContainedBit = 1 << 17 + ) + + payloadLen := len(payload) + if payloadLen > maxSegmentPayloadSize { + return nil, fmt.Errorf("gocql: payload length (%d) exceeds maximum size of %d", payloadLen, maxSegmentPayloadSize) + } + + // Create the segment + segmentSize := headerSize + payloadLen + crc32Size + segment := make([]byte, segmentSize) + + // First 3 bytes: payload length and self-contained flag + headerInt := uint32(payloadLen) + if isSelfContained { + headerInt |= selfContainedBit // Set the self-contained flag + } + + // Encode the first 3 bytes as a single little-endian integer + segment[0] = byte(headerInt) + segment[1] = byte(headerInt >> 8) + segment[2] = byte(headerInt >> 16) + + // Calculate CRC24 for the first 3 bytes of the header + crc := Crc24(segment[:3]) + + // Encode CRC24 into the next 3 bytes of the header + segment[3] = byte(crc) + segment[4] = byte(crc >> 8) + segment[5] = byte(crc >> 16) + + copy(segment[headerSize:], payload) // Copy the payload to the segment + + // Calculate CRC32 for the payload + payloadCRC32 := Crc32(payload) + binary.LittleEndian.PutUint32(segment[headerSize+payloadLen:], payloadCRC32) + + return segment, nil +} + +func newCompressedSegment(uncompressedPayload []byte, isSelfContained bool, compressor Compressor) ([]byte, error) { + const ( + headerSize = 5 + selfContainedBit = 1 << 34 + ) + + uncompressedLen := len(uncompressedPayload) + if uncompressedLen > maxSegmentPayloadSize { + return nil, fmt.Errorf("gocql: payload length (%d) exceeds maximum size of %d", uncompressedLen, maxSegmentPayloadSize) + } + + compressedPayload, err := compressor.AppendCompressed(nil, uncompressedPayload) + if err != nil { + return nil, err + } + + compressedLen := len(compressedPayload) + 
+ // Compression is not worth it + if uncompressedLen < compressedLen { + // native_protocol_v5.spec + // 2.2 + // An uncompressed length of 0 signals that the compressed payload + // should be used as-is and not decompressed. + compressedPayload = uncompressedPayload + compressedLen = uncompressedLen + uncompressedLen = 0 + } + + // Combine compressed and uncompressed lengths and set the self-contained flag if needed + combined := uint64(compressedLen) | uint64(uncompressedLen)<<17 + if isSelfContained { + combined |= selfContainedBit + } + + var headerBuf [headerSize + crc24Size]byte + + // Write the combined value into the header buffer + binary.LittleEndian.PutUint64(headerBuf[:], combined) + + // Create a buffer with enough capacity to hold the header, compressed payload, and checksums + buf := bytes.NewBuffer(make([]byte, 0, headerSize+crc24Size+compressedLen+crc32Size)) + + // Write the first 5 bytes of the header (compressed and uncompressed sizes) + buf.Write(headerBuf[:headerSize]) + + // Compute and write the CRC24 checksum of the first 5 bytes + headerChecksum := Crc24(headerBuf[:headerSize]) + + // LittleEndian 3 bytes + headerBuf[0] = byte(headerChecksum) + headerBuf[1] = byte(headerChecksum >> 8) + headerBuf[2] = byte(headerChecksum >> 16) + buf.Write(headerBuf[:3]) + + buf.Write(compressedPayload) + + // Compute and write the CRC32 checksum of the payload + payloadChecksum := Crc32(compressedPayload) + binary.LittleEndian.PutUint32(headerBuf[:], payloadChecksum) + buf.Write(headerBuf[:4]) + + return buf.Bytes(), nil +} + +func readCompressedSegment(r io.Reader, compressor Compressor) ([]byte, bool, error) { + const headerSize = 5 + var ( + headerBuf [headerSize + crc24Size]byte + err error + ) + + if _, err = io.ReadFull(r, headerBuf[:]); err != nil { + return nil, false, err + } + + // Reading checksum from frame header + readHeaderChecksum := uint32(headerBuf[5]) | uint32(headerBuf[6])<<8 | uint32(headerBuf[7])<<16 + if computedHeaderChecksum := 
Crc24(headerBuf[:headerSize]); computedHeaderChecksum != readHeaderChecksum { + return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, read: %d, computed: %d", readHeaderChecksum, computedHeaderChecksum) + } + + // First 17 bits - payload size after compression + compressedLen := uint32(headerBuf[0]) | uint32(headerBuf[1])<<8 | uint32(headerBuf[2]&0x1)<<16 + + // The next 17 bits - payload size before compression + uncompressedLen := (uint32(headerBuf[2]) >> 1) | uint32(headerBuf[3])<<7 | uint32(headerBuf[4]&0b11)<<15 + + // Self-contained flag + selfContained := (headerBuf[4] & 0b100) != 0 + + compressedPayload := make([]byte, compressedLen) + if _, err = io.ReadFull(r, compressedPayload); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read compressed frame payload, err: %w", err) + } + + if _, err = io.ReadFull(r, headerBuf[:crc32Size]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err) + } + + // Ensuring if payload checksum matches + readPayloadChecksum := binary.LittleEndian.Uint32(headerBuf[:crc32Size]) + if computedPayloadChecksum := Crc32(compressedPayload); readPayloadChecksum != computedPayloadChecksum { + return nil, false, fmt.Errorf("gocql: crc32 mismatch in payload, read: %d, computed: %d", readPayloadChecksum, computedPayloadChecksum) + } + + var uncompressedPayload []byte + if uncompressedLen > 0 { + if uncompressedPayload, err = compressor.AppendDecompressed(nil, compressedPayload, uncompressedLen); err != nil { + return nil, false, err + } + if uint32(len(uncompressedPayload)) != uncompressedLen { + return nil, false, fmt.Errorf("gocql: length mismatch after payload decoding, got %d, expected %d", len(uncompressedPayload), uncompressedLen) + } + } else { + uncompressedPayload = compressedPayload + } + + return uncompressedPayload, selfContained, nil +} diff --git a/frame_test.go b/frame_test.go index 170cba710..8cb9024a5 100644 --- a/frame_test.go +++ b/frame_test.go @@ 
-26,8 +26,12 @@ package gocql import ( "bytes" + "errors" "os" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFuzzBugs(t *testing.T) { @@ -127,3 +131,313 @@ func TestFrameReadTooLong(t *testing.T) { t.Fatalf("expected to get header %v got %v", opReady, head.op) } } + +func Test_framer_writeExecuteFrame(t *testing.T) { + framer := newFramer(nil, protoVersion5) + nowInSeconds := 123 + frame := writeExecuteFrame{ + preparedID: []byte{1, 2, 3}, + resultMetadataID: []byte{4, 5, 6}, + customPayload: map[string][]byte{ + "key1": []byte("value1"), + }, + params: queryParams{ + nowInSeconds: &nowInSeconds, + keyspace: "test_keyspace", + }, + } + + err := framer.writeExecuteFrame(123, frame.preparedID, frame.resultMetadataID, &frame.params, &frame.customPayload) + if err != nil { + t.Fatal(err) + } + + // skipping header + framer.buf = framer.buf[9:] + + assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) + assertDeepEqual(t, "preparedID", frame.preparedID, framer.readShortBytes()) + assertDeepEqual(t, "resultMetadataID", frame.resultMetadataID, framer.readShortBytes()) + assertDeepEqual(t, "constistency", frame.params.consistency, Consistency(framer.readShort())) + + flags := framer.readInt() + if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { + t.Fatal("expected flagNowInSeconds to be set, but it is not") + } + + if flags&int(flagWithKeyspace) != int(flagWithKeyspace) { + t.Fatal("expected flagWithKeyspace to be set, but it is not") + } + + assertDeepEqual(t, "keyspace", frame.params.keyspace, framer.readString()) + assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt()) +} + +func Test_framer_writeBatchFrame(t *testing.T) { + framer := newFramer(nil, protoVersion5) + nowInSeconds := 123 + frame := writeBatchFrame{ + customPayload: map[string][]byte{ + "key1": []byte("value1"), + }, + nowInSeconds: &nowInSeconds, + } + + err := framer.writeBatchFrame(123, &frame, 
frame.customPayload) + if err != nil { + t.Fatal(err) + } + + // skipping header + framer.buf = framer.buf[9:] + + assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) + assertDeepEqual(t, "typ", frame.typ, BatchType(framer.readByte())) + assertDeepEqual(t, "len(statements)", len(frame.statements), int(framer.readShort())) + assertDeepEqual(t, "consistency", frame.consistency, Consistency(framer.readShort())) + + flags := framer.readInt() + if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { + t.Fatal("expected flagNowInSeconds to be set, but it is not") + } + + assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt()) +} + +type testMockedCompressor struct { + // this is an error its methods should return + expectedError error + + // invalidateDecodedDataLength allows to simulate data decoding invalidation + invalidateDecodedDataLength bool +} + +func (m testMockedCompressor) Name() string { + return "testMockedCompressor" +} + +func (m testMockedCompressor) AppendCompressed(_, src []byte) ([]byte, error) { + if m.expectedError != nil { + return nil, m.expectedError + } + return src, nil +} + +func (m testMockedCompressor) AppendDecompressed(_, src []byte, decompressedLength uint32) ([]byte, error) { + if m.expectedError != nil { + return nil, m.expectedError + } + + // simulating invalid size of decoded data + if m.invalidateDecodedDataLength { + return src[:decompressedLength-1], nil + } + + return src, nil +} + +func (m testMockedCompressor) AppendCompressedWithLength(dst, src []byte) ([]byte, error) { + panic("testMockedCompressor.AppendCompressedWithLength is not implemented") +} + +func (m testMockedCompressor) AppendDecompressedWithLength(dst, src []byte) ([]byte, error) { + panic("testMockedCompressor.AppendDecompressedWithLength is not implemented") +} + +func Test_readUncompressedFrame(t *testing.T) { + tests := []struct { + name string + modifyFrame func([]byte) []byte + expectedErr string + }{ + { + 
name: "header crc24 mismatch", + modifyFrame: func(frame []byte) []byte { + // simulating some crc invalidation + frame[0] = 255 + return frame + }, + expectedErr: "gocql: crc24 mismatch in frame header", + }, + { + name: "body crc32 mismatch", + modifyFrame: func(frame []byte) []byte { + // simulating body crc32 mismatch + frame[len(frame)-1] = 255 + return frame + }, + expectedErr: "gocql: payload crc32 mismatch", + }, + { + name: "invalid frame length", + modifyFrame: func(frame []byte) []byte { + // simulating body length invalidation + frame = frame[:7] + return frame + }, + expectedErr: "gocql: failed to read uncompressed frame payload", + }, + { + name: "cannot read body checksum", + modifyFrame: func(frame []byte) []byte { + // simulating body length invalidation + frame = frame[:len(frame)-4] + return frame + }, + expectedErr: "gocql: failed to read payload crc32", + }, + { + name: "success", + modifyFrame: nil, + expectedErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + framer := newFramer(nil, protoVersion5) + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ + consistency: Quorum, + keyspace: "gocql_test", + }, + } + + err := req.buildFrame(framer, 128) + require.NoError(t, err) + + frame, err := newUncompressedSegment(framer.buf, true) + require.NoError(t, err) + + if tt.modifyFrame != nil { + frame = tt.modifyFrame(frame) + } + + readFrame, isSelfContained, err := readUncompressedSegment(bytes.NewReader(frame)) + + if tt.expectedErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.True(t, isSelfContained) + assert.Equal(t, framer.buf, readFrame) + } + }) + } +} + +func Test_readCompressedFrame(t *testing.T) { + tests := []struct { + name string + // modifyFrameFn is useful for simulating frame data invalidation + modifyFrameFn func([]byte) []byte + compressor testMockedCompressor + + // 
expectedErrorMsg is an error message that should be returned by Error() method. + // We need this to understand which of fmt.Errorf() is returned + expectedErrorMsg string + }{ + { + name: "header crc24 mismatch", + modifyFrameFn: func(frame []byte) []byte { + // simulating some crc invalidation + frame[0] = 255 + return frame + }, + expectedErrorMsg: "gocql: crc24 mismatch in frame header", + }, + { + name: "body crc32 mismatch", + modifyFrameFn: func(frame []byte) []byte { + // simulating body crc32 mismatch + frame[len(frame)-1] = 255 + return frame + }, + expectedErrorMsg: "gocql: crc32 mismatch in payload", + }, + { + name: "invalid frame length", + modifyFrameFn: func(frame []byte) []byte { + // simulating body length invalidation + return frame[:12] + }, + expectedErrorMsg: "gocql: failed to read compressed frame payload", + }, + { + name: "cannot read body checksum", + modifyFrameFn: func(frame []byte) []byte { + // simulating body length invalidation + return frame[:len(frame)-4] + }, + expectedErrorMsg: "gocql: failed to read payload crc32", + }, + { + name: "failed to encode payload", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + expectedError: errors.New("failed to encode payload"), + }, + expectedErrorMsg: "failed to encode payload", + }, + { + name: "failed to decode payload", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + expectedError: errors.New("failed to decode payload"), + }, + expectedErrorMsg: "failed to decode payload", + }, + { + name: "length mismatch after decoding", + modifyFrameFn: nil, + compressor: testMockedCompressor{ + invalidateDecodedDataLength: true, + }, + expectedErrorMsg: "gocql: length mismatch after payload decoding", + }, + { + name: "success", + modifyFrameFn: nil, + expectedErrorMsg: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + framer := newFramer(nil, protoVersion5) + req := writeQueryFrame{ + statement: "SELECT * FROM system.local", + params: queryParams{ 
+ consistency: Quorum, + keyspace: "gocql_test", + }, + } + + err := req.buildFrame(framer, 128) + require.NoError(t, err) + + frame, err := newCompressedSegment(framer.buf, true, testMockedCompressor{}) + require.NoError(t, err) + + if tt.modifyFrameFn != nil { + frame = tt.modifyFrameFn(frame) + } + + readFrame, selfContained, err := readCompressedSegment(bytes.NewReader(frame), tt.compressor) + + switch { + case tt.expectedErrorMsg != "": + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErrorMsg) + case tt.compressor.expectedError != nil: + require.ErrorIs(t, err, tt.compressor.expectedError) + default: + require.NoError(t, err) + assert.True(t, selfContained) + assert.Equal(t, framer.buf, readFrame) + } + }) + } +} diff --git a/go.mod b/go.mod index 0aea881ec..af4ee9e22 100644 --- a/go.mod +++ b/go.mod @@ -18,13 +18,23 @@ module github.com/gocql/gocql require ( - github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect - github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/gocql/gocql/lz4 v0.0.0-00010101000000-000000000000 github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed - github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/stretchr/testify v1.9.0 gopkg.in/inf.v0 v0.9.1 ) -go 1.13 +require ( + github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/pierrec/lz4/v4 v4.1.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/gocql/gocql/lz4 => ./lz4 + +go 1.19 diff --git a/go.sum b/go.sum index 2e3892bcb..c665f352b 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,9 @@ github.com/bitly/go-hostpool 
v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= @@ -13,10 +14,18 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= +github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lz4/lz4.go b/lz4/lz4.go index 049fdc0bb..c836a0934 100644 --- a/lz4/lz4.go +++ b/lz4/lz4.go @@ -27,7 +27,6 @@ package lz4 import ( "encoding/binary" "fmt" - "github.com/pierrec/lz4/v4" ) @@ -47,29 +46,71 @@ func (s LZ4Compressor) Name() string { return "lz4" } -func (s LZ4Compressor) Encode(data []byte) ([]byte, error) { - buf := make([]byte, lz4.CompressBlockBound(len(data)+4)) +const dataLengthSize = 4 + +func (s LZ4Compressor) AppendCompressedWithLength(dst, src []byte) ([]byte, error) { + maxLength := lz4.CompressBlockBound(len(src)) + oldDstLen := len(dst) + dst = grow(dst, maxLength+dataLengthSize) + var compressor lz4.Compressor - n, err := compressor.CompressBlock(data, buf[4:]) + n, err := compressor.CompressBlock(src, dst[oldDstLen+dataLengthSize:]) // According to lz4.CompressBlock doc, it doesn't fail as long as the dst // buffer length is at least lz4.CompressBlockBound(len(data))) bytes, but // we check for error anyway just to be thorough. 
if err != nil { return nil, err } - binary.BigEndian.PutUint32(buf, uint32(len(data))) - return buf[:n+4], nil + binary.BigEndian.PutUint32(dst[oldDstLen:oldDstLen+dataLengthSize], uint32(len(src))) + return dst[:oldDstLen+n+dataLengthSize], nil +} + +func (s LZ4Compressor) AppendDecompressedWithLength(dst, src []byte) ([]byte, error) { + if len(src) < dataLengthSize { + return nil, fmt.Errorf("cassandra lz4 block size should be >4, got=%d", len(src)) + } + uncompressedLength := binary.BigEndian.Uint32(src[:dataLengthSize]) + if uncompressedLength == 0 { + return nil, nil + } + oldDstLen := len(dst) + dst = grow(dst, int(uncompressedLength)) + n, err := lz4.UncompressBlock(src[dataLengthSize:], dst[oldDstLen:]) + return dst[:oldDstLen+n], err + } -func (s LZ4Compressor) Decode(data []byte) ([]byte, error) { - if len(data) < 4 { - return nil, fmt.Errorf("cassandra lz4 block size should be >4, got=%d", len(data)) +func (s LZ4Compressor) AppendCompressed(dst, src []byte) ([]byte, error) { + maxLength := lz4.CompressBlockBound(len(src)) + oldDstLen := len(dst) + dst = grow(dst, maxLength) + + var compressor lz4.Compressor + n, err := compressor.CompressBlock(src, dst[oldDstLen:]) + if err != nil { + return nil, err } - uncompressedLength := binary.BigEndian.Uint32(data) + + return dst[:oldDstLen+n], nil +} + +func (s LZ4Compressor) AppendDecompressed(dst, src []byte, uncompressedLength uint32) ([]byte, error) { if uncompressedLength == 0 { return nil, nil } - buf := make([]byte, uncompressedLength) - n, err := lz4.UncompressBlock(data[4:], buf) - return buf[:n], err + oldDstLen := len(dst) + dst = grow(dst, int(uncompressedLength)) + n, err := lz4.UncompressBlock(src, dst[oldDstLen:]) + return dst[:oldDstLen+n], err +} + +// grow grows b to guaranty space for n elements, if needed. 
+func grow(b []byte, n int) []byte { + oldLen := len(b) + if cap(b)-oldLen < n { + newBuf := make([]byte, oldLen+n) + copy(newBuf, b) + b = newBuf + } + return b[:oldLen+n] } diff --git a/lz4/lz4_test.go b/lz4/lz4_test.go index e0834b948..379afd4d8 100644 --- a/lz4/lz4_test.go +++ b/lz4/lz4_test.go @@ -25,6 +25,7 @@ package lz4 import ( + "github.com/pierrec/lz4/v4" "testing" "github.com/stretchr/testify/require" @@ -34,21 +35,215 @@ func TestLZ4Compressor(t *testing.T) { var c LZ4Compressor require.Equal(t, "lz4", c.Name()) - _, err := c.Decode([]byte{0, 1, 2}) + _, err := c.AppendDecompressedWithLength(nil, []byte{0, 1, 2}) require.EqualError(t, err, "cassandra lz4 block size should be >4, got=3") - _, err = c.Decode([]byte{0, 1, 2, 4, 5}) + _, err = c.AppendDecompressedWithLength(nil, []byte{0, 1, 2, 4, 5}) require.EqualError(t, err, "lz4: invalid source or destination buffer too short") // If uncompressed size is zero then nothing is decoded even if present. - decoded, err := c.Decode([]byte{0, 0, 0, 0, 5, 7, 8}) + decoded, err := c.AppendDecompressedWithLength(nil, []byte{0, 0, 0, 0, 5, 7, 8}) require.NoError(t, err) require.Nil(t, decoded) original := []byte("My Test String") - encoded, err := c.Encode(original) + encoded, err := c.AppendCompressedWithLength(nil, original) require.NoError(t, err) - decoded, err = c.Decode(encoded) + decoded, err = c.AppendDecompressedWithLength(nil, encoded) require.NoError(t, err) require.Equal(t, original, decoded) } + +func TestLZ4Compressor_AppendCompressedDecompressed(t *testing.T) { + c := LZ4Compressor{} + + invalidUncompressedLength := uint32(10) + _, err := c.AppendDecompressed(nil, []byte{0, 1, 2, 4, 5}, invalidUncompressedLength) + require.EqualError(t, err, "lz4: invalid source or destination buffer too short") + + original := []byte("My Test String") + encoded, err := c.AppendCompressed(nil, original) + require.NoError(t, err) + decoded, err := c.AppendDecompressed(nil, encoded, uint32(len(original))) + 
require.NoError(t, err) + require.Equal(t, original, decoded) +} + +func TestLZ4Compressor_AppendWithLengthGrowSliceWithData(t *testing.T) { + var tests = []struct { + name string + src []byte + dst []byte + shouldReuseDst bool + decodeDst []byte + shouldReuseDecodeDst bool + }{ + { + name: "both dst are empty", + src: []byte("small data"), + dst: nil, + decodeDst: nil, + }, + { + name: "dst is nil", + src: []byte("another piece of data"), + dst: nil, + decodeDst: []byte("something"), + }, + { + name: "decodeDst is nil", + src: []byte("another piece of data"), + dst: []byte("some"), + decodeDst: nil, + }, + { + name: "both dst are not empty", + src: []byte("another piece of data"), + dst: []byte("dst"), + decodeDst: []byte("decodeDst"), + }, + { + name: "both dst slices have enough capacity", + src: []byte("small"), + dst: createBufWithCapAndData("cap=128", 128), + shouldReuseDst: true, + decodeDst: createBufWithCapAndData("cap=256", 256), + shouldReuseDecodeDst: true, + }, + { + name: "both dsts have some data and not enough capacity", + src: []byte("small"), + dst: createBufWithCapAndData("data", 6), + decodeDst: createBufWithCapAndData("wow", 4), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + compressor := LZ4Compressor{} + + // Appending compressed data to dst, + // expecting that dst still contains "test" + result, err := compressor.AppendCompressedWithLength(tt.dst, tt.src) + require.NoError(t, err) + + var expectedCap int + if tt.shouldReuseDst { + expectedCap = cap(tt.dst) + } else { + expectedCap = len(tt.dst) + lz4.CompressBlockBound(len(tt.src)) + dataLengthSize + } + + require.Equal(t, expectedCap, cap(result)) + if len(tt.dst) > 0 { + require.Equal(t, tt.dst, result[:len(tt.dst)]) + } + + result, err = compressor.AppendDecompressedWithLength(tt.decodeDst, result[len(tt.dst):]) + require.NoError(t, err) + + var expectedDecodeCap int + if tt.shouldReuseDecodeDst { + expectedDecodeCap = cap(tt.decodeDst) + } else { + 
expectedDecodeCap = len(tt.decodeDst) + len(tt.src) + } + + require.Equal(t, expectedDecodeCap, cap(result)) + require.Equal(t, tt.src, result[len(tt.decodeDst):]) + }) + } +} + +func TestLZ4Compressor_AppendGrowSliceWithData(t *testing.T) { + var tests = []struct { + name string + src []byte + dst []byte + shouldReuseDst bool + decodeDst []byte + shouldReuseDecodeDst bool + }{ + { + name: "both dst are empty", + src: []byte("small data"), + dst: nil, + decodeDst: nil, + }, + { + name: "dst is nil", + src: []byte("another piece of data"), + dst: nil, + decodeDst: []byte("something"), + }, + { + name: "decodeDst is nil", + src: []byte("another piece of data"), + dst: []byte("some"), + decodeDst: nil, + }, + { + name: "both dst are not empty", + src: []byte("another piece of data"), + dst: []byte("dst"), + decodeDst: []byte("decodeDst"), + }, + { + name: "both dst slices have enough capacity", + src: []byte("small"), + dst: createBufWithCapAndData("cap=128", 128), + shouldReuseDst: true, + decodeDst: createBufWithCapAndData("cap=256", 256), + shouldReuseDecodeDst: true, + }, + { + name: "both dst slices have some data and not enough capacity", + src: []byte("small"), + dst: createBufWithCapAndData("data", 6), + decodeDst: createBufWithCapAndData("wow", 4), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + compressor := LZ4Compressor{} + + // Appending compressed data to dst, + // expecting that dst still contains "test" + result, err := compressor.AppendCompressed(tt.dst, tt.src) + require.NoError(t, err) + + var expectedCap int + if tt.shouldReuseDst { + expectedCap = cap(tt.dst) + } else { + expectedCap = len(tt.dst) + lz4.CompressBlockBound(len(tt.src)) + } + + require.Equal(t, expectedCap, cap(result)) + if len(tt.dst) > 0 { + require.Equal(t, tt.dst, result[:len(tt.dst)]) + } + + uncompressedLen := uint32(len(tt.src)) + result, err = compressor.AppendDecompressed(tt.decodeDst, result[len(tt.dst):], uncompressedLen) + 
require.NoError(t, err) + + var expectedDecodeCap int + if tt.shouldReuseDst { + expectedDecodeCap = cap(tt.decodeDst) + } else { + expectedDecodeCap = len(tt.decodeDst) + len(tt.src) + } + + require.Equal(t, expectedDecodeCap, cap(result)) + require.Equal(t, tt.src, result[len(tt.decodeDst):]) + }) + } +} + +func createBufWithCapAndData(data string, cap int) []byte { + buf := make([]byte, cap) + copy(buf, data) + return buf[:len(data)] +} diff --git a/prepared_cache.go b/prepared_cache.go index 3fd256d33..7f5533a2d 100644 --- a/prepared_cache.go +++ b/prepared_cache.go @@ -100,3 +100,20 @@ func (p *preparedLRU) evictPreparedID(key string, id []byte) { } } + +func (p *preparedLRU) get(key string) (*inflightPrepare, bool) { + p.mu.Lock() + defer p.mu.Unlock() + + val, ok := p.lru.Get(key) + if !ok { + return nil, false + } + + ifp, ok := val.(*inflightPrepare) + if !ok { + return nil, false + } + + return ifp, true +} diff --git a/session.go b/session.go index a600b95f3..774b2cc09 100644 --- a/session.go +++ b/session.go @@ -591,11 +591,20 @@ func (s *Session) getConn() *Conn { return nil } -// returns routing key indexes and type info -func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) { +// Returns routing key indexes and type info. 
+// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace +func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) { + if keyspace == "" { + keyspace = s.cfg.Keyspace + } + + routingKeyInfoCacheKey := keyspace + stmt + s.routingKeyInfoCache.mu.Lock() - entry, cached := s.routingKeyInfoCache.lru.Get(stmt) + // Using here keyspace + stmt as a cache key because + // the query keyspace could be overridden via SetKeyspace + entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey) if cached { // done accessing the cache s.routingKeyInfoCache.mu.Unlock() @@ -619,7 +628,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI inflight := new(inflightCachedEntry) inflight.wg.Add(1) defer inflight.wg.Done() - s.routingKeyInfoCache.lru.Add(stmt, inflight) + s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight) s.routingKeyInfoCache.mu.Unlock() var ( @@ -635,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil) + info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace) if inflight.err != nil { // don't cache this error s.routingKeyInfoCache.Remove(stmt) @@ -651,7 +660,9 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } table := info.request.table - keyspace := info.request.keyspace + if info.request.keyspace != "" { + keyspace = info.request.keyspace + } if len(info.request.pkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns @@ -936,6 +947,9 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. 
routingInfo *queryRoutingInfo + + keyspace string + nowInSecondsValue *int } type queryRoutingInfo struct { @@ -1143,6 +1157,9 @@ func (q *Query) Keyspace() string { if q.routingInfo.keyspace != "" { return q.routingInfo.keyspace } + if q.keyspace != "" { + return q.keyspace + } if q.session == nil { return "" @@ -1174,7 +1191,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) { } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) if err != nil { return nil, err } @@ -1423,6 +1440,24 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetKeyspace will enable keyspace flag on the query. +// It allows to specify the keyspace that the query should be executed in +// +// Only available on protocol >= 5. +func (q *Query) SetKeyspace(keyspace string) *Query { + q.keyspace = keyspace + return q +} + +// WithNowInSeconds will enable the with now_in_seconds flag on the query. +// Also, it allows to define now_in_seconds value. +// +// Only available on protocol >= 5. +func (q *Query) WithNowInSeconds(now int) *Query { + q.nowInSecondsValue = &now + return q +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -1742,6 +1777,7 @@ type Batch struct { cancelBatch func() keyspace string metrics *queryMetrics + nowInSeconds *int // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. 
routingInfo *queryRoutingInfo @@ -1987,7 +2023,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) { return nil, nil } // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) if err != nil { return nil, err } @@ -2042,6 +2078,24 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +// SetKeyspace will enable keyspace flag on the query. +// It allows to specify the keyspace that the query should be executed in +// +// Only available on protocol >= 5. +func (b *Batch) SetKeyspace(keyspace string) *Batch { + b.keyspace = keyspace + return b +} + +// WithNowInSeconds will enable the with now_in_seconds flag on the query. +// Also, it allows to define now_in_seconds value. +// +// Only available on protocol >= 5. +func (b *Batch) WithNowInSeconds(now int) *Batch { + b.nowInSeconds = &now + return b +} + type BatchType byte const ( diff --git a/snappy/snappy_compressor.go b/snappy/snappy_compressor.go new file mode 100644 index 000000000..faec4a722 --- /dev/null +++ b/snappy/snappy_compressor.go @@ -0,0 +1,28 @@ +package snappy + +import "github.com/golang/snappy" + +// SnappyCompressor implements the Compressor interface and can be used to +// compress incoming and outgoing frames. The snappy compression algorithm +// aims for very high speeds and reasonable compression. 
+type SnappyCompressor struct{} + +func (s SnappyCompressor) Name() string { + return "snappy" +} + +func (s SnappyCompressor) AppendCompressedWithLength(dst, src []byte) ([]byte, error) { + return snappy.Encode(dst, src), nil +} + +func (s SnappyCompressor) AppendDecompressedWithLength(dst, src []byte) ([]byte, error) { + return snappy.Decode(dst, src) +} + +func (s SnappyCompressor) AppendCompressed(dst, src []byte) ([]byte, error) { + panic("SnappyCompressor.AppendCompressed is not supported") +} + +func (s SnappyCompressor) AppendDecompressed(dst, src []byte, decompressedLength uint32) ([]byte, error) { + panic("SnappyCompressor.AppendDecompressed is not supported") +} diff --git a/compressor_test.go b/snappy/snappy_test.go similarity index 89% rename from compressor_test.go rename to snappy/snappy_test.go index 20cf934ea..3efe3fa70 100644 --- a/compressor_test.go +++ b/snappy/snappy_test.go @@ -22,7 +22,7 @@ * See the NOTICE file distributed with this work for additional information. 
*/ -package gocql +package snappy import ( "bytes" @@ -40,13 +40,13 @@ func TestSnappyCompressor(t *testing.T) { str := "My Test String" //Test Encoding expected := snappy.Encode(nil, []byte(str)) - if res, err := c.Encode([]byte(str)); err != nil { + if res, err := c.AppendCompressedWithLength(nil, []byte(str)); err != nil { t.Fatalf("failed to encode '%v' with error %v", str, err) } else if bytes.Compare(expected, res) != 0 { t.Fatal("failed to match the expected encoded value with the result encoded value.") } - val, err := c.Encode([]byte(str)) + val, err := c.AppendCompressedWithLength(nil, []byte(str)) if err != nil { t.Fatalf("failed to encode '%v' with error '%v'", str, err) } @@ -54,7 +54,7 @@ func TestSnappyCompressor(t *testing.T) { //Test Decoding if expected, err := snappy.Decode(nil, val); err != nil { t.Fatalf("failed to decode '%v' with error %v", val, err) - } else if res, err := c.Decode(val); err != nil { + } else if res, err := c.AppendDecompressedWithLength(nil, val); err != nil { t.Fatalf("failed to decode '%v' with error %v", val, err) } else if bytes.Compare(expected, res) != 0 { t.Fatal("failed to match the expected decoded value with the result decoded value.")