diff --git a/helpers.go b/helpers.go index 00f339779..cd7cb52a7 100644 --- a/helpers.go +++ b/helpers.go @@ -187,6 +187,48 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { } } +// getCassandraTypeWithVersion is like getCassandraType but set proto in underling Type +func getCassandraTypeWithVersion(name string, logger StdLogger, proto byte) TypeInfo { + if strings.HasPrefix(name, "frozen<") { + return getCassandraTypeWithVersion(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger, proto) + } else if strings.HasPrefix(name, "set<") { + return CollectionType{ + NativeType: NewNativeType(proto, TypeSet, ""), + Elem: getCassandraTypeWithVersion(strings.TrimPrefix(name[:len(name)-1], "set<"), logger, proto), + } + } else if strings.HasPrefix(name, "list<") { + return CollectionType{ + NativeType: NewNativeType(proto, TypeList, ""), + Elem: getCassandraTypeWithVersion(strings.TrimPrefix(name[:len(name)-1], "list<"), logger, proto), + } + } else if strings.HasPrefix(name, "map<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + if len(names) != 2 { + logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NewNativeType(proto, TypeCustom, "") + } + return CollectionType{ + NativeType: NewNativeType(proto, TypeMap, ""), + Key: getCassandraType(names[0], logger), + Elem: getCassandraType(names[1], logger), + } + } else if strings.HasPrefix(name, "tuple<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraType(name, logger) + } + + return TupleTypeInfo{ + NativeType: NewNativeType(proto, TypeTuple, ""), + Elems: types, + } + } else { + return NewNativeType(proto, getCassandraBaseType(name), "") + } +} + func splitCompositeTypes(name string) []string { if !strings.Contains(name, "<") { return strings.Split(name, ", ") diff --git a/scylla_shard_aware_extension.go b/scylla_shard_aware_extension.go new file mode 100644 index 000000000..c20d39c67 --- /dev/null +++ b/scylla_shard_aware_extension.go @@ -0,0 +1,237 @@ +package gocql + +import ( + "errors" + "fmt" + "sort" +) + +// ShardAwareRoutingInfo - information about the routing of the request (host and shard on which the request must be made). +// This information will help group requests (or keys) into batches by host and/or shard. + +type ShardAwareRoutingInfo struct { + // RoutingKey - is bytes of primary key + RoutingKey []byte + // Host - is node to connect (HostAware policy) + Host *HostInfo + // Shard - is shard ID of node to connect (ShardAware policy) + Shard int +} + +// GetShardAwareRoutingInfo - identifies a node/shard by PK key. +// +// The driver may not always receive routing information and this is normal. +// In this case, the function will return an error and your application needs to process it normally. +// +// Example for SELECT WHERE IN: +/* + const shardsAbout = 100 // node * (cpu-1) + + // Split []T by chunks + var ( + queryBatches = make(map[string][]T, shardsAbout) // []T grouped by chunks + routingKeys = make(map[string][]byte, shardsAbout) // routing key for query + ) + + for _, pk := range pks { + var ( + shardID string + routingKey []byte + ) + // We receive information about the routing of our keys. + // In this example, PRIMARY KEY consists of one column pk_column_name. + info, err := session.GetShardAwareRoutingInfo(keyspaceName, tableName, []string{"pk_column_name"}, pk) + if err != nil || info.Host == nil { + // We may not get routing information for various reasons (change shema topology, etc). + // It is important to understand the reason when testing (for example, you are not using tokenAwarePolicy) + log.Printf("can't get shard id of pk '%d': %v", pk, err) + } else { + // build key: host + "/" + vShard (127.0.0.1/1) + shardID = info.Host.Hostname() + "/" + strconv.Itoa(info.Shard) + routingKey = info.RoutingKey + } + + // Put key to corresponding batch + batch := queryBatches[shardID] + if batch == nil { + batch = make([]int64, 0, len(pks)/shardsAbout) + } + batch = append(batch, pk) + + queryBatches[shardID] = batch + routingKeys[shardID] = rk + } + + const query = "SELECT * FROM table_name WHERE pk IN (?)" + + var wg sync.WaitGroup + // we go through all the batches to execute queries in parallel + for shard, batch := range batches { + // We divide large batches into smaller chunks, since large batches in SELECT queries have a bad effect on RT scylla + for _, chunk := range slices.ChunkSlice(batch, 10) { // slices.ChunkSlice some function that splits slice by N slices of M or less lenght (in our example M=10) + wg.Add(1) + go func(shard string, chunk []int64) { + defer wg.Done() + + rk := keys[shard] // get our routing key + + scanner := r.session.Query(query, chunk).RoutingKey(rk).Iter().Scanner() // use RoutingKey + + for scanner.Next() { + // ... + } + + if err := scanner.Err(); err != nil { + // ... + } + }(shard, chunk) + } + } + // wait for all answers + wg.Wait() + // NOTE: this is not the most optimal strategy 'cause we're waiting for all queries done. + // If at least one query has long response time it will affects on the response time of our method. (RT our method = max RT of queries) + // The best approach is to build pipeline handling your results using golang channels and so on... +*/ +func (s *Session) GetShardAwareRoutingInfo(table string, colums []string, values ...interface{}) (ShardAwareRoutingInfo, error) { + keyspace := s.cfg.Keyspace + + // fail fast + if len(keyspace) == 0 || len(table) == 0 || len(colums) == 0 || len(values) == 0 { + return ShardAwareRoutingInfo{}, errors.New("missing keyspace, table, columns or values") + } + + // check that host policy is TokenAwareHostPolicy + tokenAwarePolicy, ok := s.policy.(*tokenAwareHostPolicy) + if !ok { + // host policy is not TokenAwareHostPolicy + return ShardAwareRoutingInfo{}, fmt.Errorf("unsupported host policy type %T, must be tokenAwareHostPolicy", s.policy) + } + + // get keyspace metadata + keyspaceMetadata, err := s.KeyspaceMetadata(keyspace) + if err != nil { + return ShardAwareRoutingInfo{}, fmt.Errorf("can't get keyspace %v metadata", keyspace) + } + + // get table metadata + tableMetadata, ok := keyspaceMetadata.Tables[table] + if !ok { + return ShardAwareRoutingInfo{}, fmt.Errorf("table %v metadata not found", table) + } + + // get token metadata + tokenMetadata := tokenAwarePolicy.getMetadataReadOnly() + if tokenMetadata == nil || tokenMetadata.tokenRing == nil { + return ShardAwareRoutingInfo{}, errors.New("can't get token ring metadata") + } + + // get routing key + routingKey, err := getRoutingKey(tableMetadata.PartitionKey, s.connCfg.ProtoVersion, s.logger, colums, values...) + if err != nil { + return ShardAwareRoutingInfo{}, err + } + + // get token from partition key + token := tokenMetadata.tokenRing.partitioner.Hash(routingKey) + mm3token, ok := token.(int64Token) // check if that's murmur3 token + if !ok { + return ShardAwareRoutingInfo{}, fmt.Errorf("unsupported token type %T, must be int64Token", token) + } + + // get hosts by token + var hosts []*HostInfo + if ht := tokenMetadata.replicas[keyspace].replicasFor(mm3token); ht != nil { + hosts = make([]*HostInfo, len(ht.hosts)) + copy(hosts, ht.hosts) // need copy because of later we will sort hosts + } else { + host, _ := tokenMetadata.tokenRing.GetHostForToken(mm3token) + hosts = []*HostInfo{host} + } + + getHostTier := func(h *HostInfo) uint { + if tierer, tiererOk := tokenAwarePolicy.fallback.(HostTierer); tiererOk { // e.g. RackAware + return tierer.HostTier(h) + } else if tokenAwarePolicy.fallback.IsLocal(h) { // e.g. DCAware + return 0 + } else { // e.g. RoundRobin + return 1 + } + } + + // sortable hosts according to the host policy (e.g. local DC places first, then the rest) + sort.Slice(hosts, func(i, j int) bool { + return getHostTier(hosts[i]) < getHostTier(hosts[j]) + }) + + // select host + for _, host := range hosts { + if !host.IsUp() { + // host is not ready to accept our query, skip it + s.logger.Printf("GetShardAwareRoutingInfo: skip host %s: host is not ready", host.Hostname()) + continue + } + + // get host connection pool + pool, ok := s.pool.getPool(host) + if !ok { + s.logger.Printf("GetShardAwareRoutingInfo: skip host %s: can't get host connection pool", host.Hostname()) + continue + } + + // check that connection pool is scylla pool + cp, ok := pool.connPicker.(*scyllaConnPicker) + if !ok { + s.logger.Printf("GetShardAwareRoutingInfo: skip host %s: unsupported connection picker type %T, must be scyllaConnPicker", host.Hostname(), pool.connPicker) + continue + } + + // return Shard Aware info + return ShardAwareRoutingInfo{ + RoutingKey: routingKey, // routing key + Host: host, // host by key (for HostAware policy) + Shard: cp.shardOf(mm3token), // calculate shard id (for ShardAware policy) + }, nil + } + + return ShardAwareRoutingInfo{}, fmt.Errorf("no avilable hosts for token %d", mm3token) +} + +func getRoutingKey( + partitionKey []*ColumnMetadata, + protoVersion int, + logger StdLogger, + columns []string, + values ...interface{}, +) ([]byte, error) { + var ( + indexes = make([]int, len(partitionKey)) + types = make([]TypeInfo, len(partitionKey)) + ) + for keyIndex, keyColumn := range partitionKey { + // set an indicator for checking if the mapping is missing + indexes[keyIndex] = -1 + + // find the column in the query info + for argIndex, boundColumnName := range columns { + if keyColumn.Name == boundColumnName { + // there may be many such bound columns, pick the first + indexes[keyIndex] = argIndex + types[keyIndex] = getCassandraTypeWithVersion(keyColumn.Type, logger, byte(protoVersion)) + break + } + } + + if indexes[keyIndex] == -1 { + // missing a routing key column mapping + // no routing key + return nil, errors.New("missing a routing key column") + } + } + + // create routing key + return createRoutingKey(&routingKeyInfo{ + indexes: indexes, + types: types, + }, values) +} diff --git a/scylla_shard_aware_extension_integration_test.go b/scylla_shard_aware_extension_integration_test.go new file mode 100644 index 000000000..627a204c6 --- /dev/null +++ b/scylla_shard_aware_extension_integration_test.go @@ -0,0 +1,93 @@ +//go:build integration && scylla +// +build integration,scylla + +package gocql + +import ( + "fmt" + "reflect" + "testing" +) + +func TestSession_GetShardAwareRoutingInfo_Integration(t *testing.T) { + const ( + keyspace = "gocql_scylla_shard_aware" + table = "test_column_metadata" + ) + + // prepare + { + cluster := createCluster(func(cc *ClusterConfig) { + cc.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) + }) + + session, err := cluster.CreateSession() + if err != nil { + t.Fatalf("failed to create session '%v'", err) + } + + // best practice: add clean up + t.Cleanup(func() { + defer session.Close() // close session after tests + + // clear DB + if err := createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace); err != nil { + t.Logf(fmt.Sprintf("unable to drop keyspace: %v", err)) + } + }) + + err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) + if err != nil { + t.Fatalf(fmt.Sprintf("unable to drop keyspace: %v", err)) + } + + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class': 'NetworkTopologyStrategy', + 'replication_factor' : %d + }`, keyspace, *flagRF)) + if err != nil { + t.Fatalf(fmt.Sprintf("unable to create keyspace: %v", err)) + } + + err = createTable(session, fmt.Sprintf("CREATE TABLE %s.%s (first_id int, second_id int, third_id int, PRIMARY KEY ((first_id, second_id)))", keyspace, table)) + if err != nil { + t.Fatalf("failed to create table with error '%v'", err) + } + } + + cluster := createCluster(func(cc *ClusterConfig) { + cc.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) + cc.Keyspace = keyspace + }) + + session, err := cluster.CreateSession() + if err != nil { + t.Fatalf("failed to create session '%v'", err) + } + defer session.Close() + + info, err := session.GetShardAwareRoutingInfo(table, []string{"first_id", "second_id", "third_id"}, 1, 2, 3) + if err != nil { + t.Fatalf("failed to get shared aware routing info '%v'", err) + } + + if info.Host == nil { + t.Fatal("empty host info") + } + + // composite key PC key (1,2) + var ( + mask = []byte{0, 4} + delim = byte(0) + // []byte{0, 0, 0, 1} == 1 + // []byte{0, 0, 0, 2} == 2 + want = append(append(append(append(append(mask, []byte{0, 0, 0, 1}...), delim), mask...), []byte{0, 0, 0, 2}...), delim) + ) + + if !reflect.DeepEqual(info.RoutingKey, want) { + t.Fatalf("routing key want: '%v', got: '%v'", want, info.RoutingKey) + } + + t.Logf("shard=%d, hostname=%s", info.Shard, info.Host.hostname) +} diff --git a/scylla_shard_aware_extension_test.go b/scylla_shard_aware_extension_test.go new file mode 100644 index 000000000..34e34fd6b --- /dev/null +++ b/scylla_shard_aware_extension_test.go @@ -0,0 +1,872 @@ +//go:build all || unit +// +build all unit + +package gocql + +import ( + "math" + "reflect" + "sync/atomic" + "testing" +) + +type partitionerMock struct { + t token +} + +func (pm partitionerMock) Name() string { + return "mock" +} + +func (pm partitionerMock) Hash([]byte) token { + return pm.t +} + +func (pm partitionerMock) ParseString(string) token { + return pm.t +} + +func Benchmark_GetShardAwareRoutingInfo(b *testing.B) { + const ( + keyspaceName = "keyspace" + tableName = "table" + partitionKeyName = "pk_column" + host1ID = "host1" + host2ID = "host2" + protoVersion = 4 + ) + + type any = interface{} // remove in go 1.18+ + + tt := struct { + schemaDescriber *schemaDescriber + connCfg *ConnConfig + pool *policyConnPool + policy HostSelectionPolicy + isClosed bool + }{ + policy: &tokenAwareHostPolicy{ + fallback: RoundRobinHostPolicy(), + metadata: func(val any) atomic.Value { + av := atomic.Value{} + av.Store(val) + return av + }(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{ + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }, + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + replicas: map[string]tokenRingReplicas{ + keyspaceName: { + { + token: scyllaCDCMinToken, + hosts: []*HostInfo{ + { + hostId: host1ID, + state: NodeUp, + }, + { + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + }, + }, + ), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: map[string]*KeyspaceMetadata{ + keyspaceName: { + Tables: map[string]*TableMetadata{ + tableName: { + PartitionKey: []*ColumnMetadata{ + { + Name: partitionKeyName, + Type: "int", + }, + }, + }, + }, + }, + }, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: &scyllaConnPicker{ + nrShards: 1, + msbIgnore: 1, + }, + }, + }, + }, + } + + s := &Session{ + schemaDescriber: tt.schemaDescriber, + connCfg: tt.connCfg, + pool: tt.pool, + policy: tt.policy, + isClosed: tt.isClosed, + logger: Logger, + } + s.cfg.Keyspace = keyspaceName + + var ( + columns = []string{partitionKeyName} + values = []interface{}{1} + ) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := s.GetShardAwareRoutingInfo(tableName, columns, values...) + if err != nil { + b.Error(err) + } + } + +} + +func TestSession_GetShardAwareRoutingInfo(t *testing.T) { + type any = interface{} // remove in go 1.18+ + + const ( + keyspaceName = "keyspace" + tableName = "table" + partitionKeyName = "pk_column" + host1ID = "host1" + host2ID = "host2" + localDC = "DC1" + localRack = "Rack1" + nonLocalDC = "DC2" + nonlocalRack = "Rack2" + protoVersion = 4 + ) + var ( + keyspaceMetadata = map[string]*KeyspaceMetadata{ + keyspaceName: { + Tables: map[string]*TableMetadata{ + tableName: { + PartitionKey: []*ColumnMetadata{ + { + Name: partitionKeyName, + Type: "int", + }, + }, + }, + }, + }, + } + + store = func(val any) atomic.Value { + av := atomic.Value{} + av.Store(val) + return av + } + ) + + type fields struct { + schemaDescriber *schemaDescriber + connCfg *ConnConfig + pool *policyConnPool + policy HostSelectionPolicy + isClosed bool + } + type args struct { + keyspace string + table string + primaryKeyColumnNames []string + args []interface{} + } + tests := []struct { + name string + fields fields + args args + want ShardAwareRoutingInfo + wantErr bool + }{ + { + name: "Test 1. empty keyspace", + args: args{ + keyspace: "", + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 2. empty table name", + args: args{ + keyspace: keyspaceName, + table: "", + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 3. empty columns name", + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 4. empty values name", + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{}, + args: []interface{}{1}, + }, + wantErr: true, + }, + + { + name: "Test 5. Not token aware policy", + fields: fields{ + policy: new(dcAwareRR), // not token aware policy + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + + { + name: "Test 6.1. Can't get keyspace metadata", + fields: fields{ + policy: new(tokenAwareHostPolicy), + isClosed: true, // closed session + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + + { + name: "Test 6.2. Can't get keyspace metadata", + fields: fields{ + policy: new(tokenAwareHostPolicy), + schemaDescriber: &schemaDescriber{ + cache: map[string]*KeyspaceMetadata{}, // no keyspace metadata + session: &Session{ + useSystemSchema: false, // failed to get keyspace metadata + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 7. Can't get table metadata", + fields: fields{ + policy: new(tokenAwareHostPolicy), + schemaDescriber: &schemaDescriber{ + cache: map[string]*KeyspaceMetadata{ + keyspaceName: { + Tables: map[string]*TableMetadata{}, // no table metadata + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 8.1. Can't get token metadata", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: atomic.Value{}, // empty token metadata + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 8.2. Can't get token metadata", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: nil, // no token ring metadata + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 9. Missing partition key column", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{"not_pk_column"}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 10. Can't create routing key", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{math.MaxInt64}, // Marshal error (int64 to cassandra int) + }, + wantErr: true, + }, + { + name: "Test 11. Not Murmur3Partitioner", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: orderedToken("")}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 12. Hosts Provider = TokenRing. Host is Down. No coon pool", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeDown, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 13. Hosts Provider = TokenRing. Host is Up. Empty coon pool", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{}, // no hosts + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 14. Hosts Provider = TokenRing. Host is Up. Not scylla conn picke", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: newDefaultConnPicker(1), // not scylla conn picker + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: true, + }, + { + name: "Test 15. Hosts Provider = TokenRing. Host is Up. OK", + fields: fields{ + policy: &tokenAwareHostPolicy{ + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{{ + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }}, + }, + }), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: &scyllaConnPicker{ + nrShards: 1, + msbIgnore: 1, + }, + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: false, + want: ShardAwareRoutingInfo{ + RoutingKey: []byte{0, 0, 0, 1}, + Host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + Shard: 0, + }, + }, + { + name: "Test 16. Hosts Provider = replicas. DCAware. OK", + fields: fields{ + policy: &tokenAwareHostPolicy{ + fallback: DCAwareRoundRobinPolicy(localDC), + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{ + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }, + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + replicas: map[string]tokenRingReplicas{ + keyspaceName: { + { + token: scyllaCDCMinToken, + hosts: []*HostInfo{ + { + hostId: host1ID, + state: NodeUp, + dataCenter: nonLocalDC, // up but not in local DC + }, + { + hostId: host2ID, + state: NodeDown, + dataCenter: localDC, // in local DC but not ready + }, + }, + }, + }, + }, + }, + ), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: &scyllaConnPicker{ + nrShards: 1, + msbIgnore: 1, + }, + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: false, + want: ShardAwareRoutingInfo{ + RoutingKey: []byte{0, 0, 0, 1}, + Host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + dataCenter: nonLocalDC, + }, + Shard: 0, + }, + }, + { + name: "Test 17. Hosts Provider = replicas. DCAware. OK", + fields: fields{ + policy: &tokenAwareHostPolicy{ + fallback: RackAwareRoundRobinPolicy(localDC, localRack), + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{ + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }, + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + replicas: map[string]tokenRingReplicas{ + keyspaceName: { + { + token: scyllaCDCMinToken, + hosts: []*HostInfo{ + { + hostId: host1ID, + state: NodeUp, + dataCenter: localDC, + rack: nonlocalRack, // in local DC and not local rack but host is UP + }, + { + hostId: host2ID, + state: NodeDown, + dataCenter: localDC, + rack: localRack, // in local DC and local rack but not ready + }, + }, + }, + }, + }, + }, + ), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: &scyllaConnPicker{ + nrShards: 1, + msbIgnore: 1, + }, + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: false, + want: ShardAwareRoutingInfo{ + RoutingKey: []byte{0, 0, 0, 1}, + Host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + dataCenter: localDC, + rack: nonlocalRack, + }, + Shard: 0, + }, + }, + { + name: "Test 18. Hosts Provider = replicas. RoundRobin. OK", + fields: fields{ + policy: &tokenAwareHostPolicy{ + fallback: RoundRobinHostPolicy(), + metadata: store(&clusterMeta{ + tokenRing: &tokenRing{ + partitioner: partitionerMock{t: scyllaCDCMinToken}, + tokens: []hostToken{ + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + }, + { + token: scyllaCDCMinToken, + host: &HostInfo{ + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + replicas: map[string]tokenRingReplicas{ + keyspaceName: { + { + token: scyllaCDCMinToken, + hosts: []*HostInfo{ + { + hostId: host1ID, + state: NodeUp, + }, + { + hostId: host2ID, + state: NodeDown, + }, + }, + }, + }, + }, + }, + ), + }, + connCfg: &ConnConfig{ + ProtoVersion: protoVersion, // no panic in marshal + }, + schemaDescriber: &schemaDescriber{ + cache: keyspaceMetadata, + }, + pool: &policyConnPool{ + hostConnPools: map[string]*hostConnPool{ + host1ID: { + connPicker: &scyllaConnPicker{ + nrShards: 1, + msbIgnore: 1, + }, + }, + }, + }, + }, + args: args{ + keyspace: keyspaceName, + table: tableName, + primaryKeyColumnNames: []string{partitionKeyName}, + args: []interface{}{1}, + }, + wantErr: false, + want: ShardAwareRoutingInfo{ + RoutingKey: []byte{0, 0, 0, 1}, + Host: &HostInfo{ + hostId: host1ID, + state: NodeUp, + }, + Shard: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{ + schemaDescriber: tt.fields.schemaDescriber, + connCfg: tt.fields.connCfg, + pool: tt.fields.pool, + policy: tt.fields.policy, + isClosed: tt.fields.isClosed, + logger: Logger, + } + s.cfg.Keyspace = tt.args.keyspace + + got, err := s.GetShardAwareRoutingInfo(tt.args.table, tt.args.primaryKeyColumnNames, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("Session.GetShardAwareRoutingInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Session.GetShardAwareRoutingInfo() = %v, want %v", got, tt.want) + } + }) + } +}