Skip to content

Commit

Permalink
SetHost API for Query
Browse files Browse the repository at this point in the history
Query.SetHost() allows users to specify on which node the Query will be executed.
It is not a tipycal use case, but it makes sense with virtual tables which are available since C* 5.0.0.

Patch by Bohdan Siryk; Reviewed by João Reis for CASSGO-4
  • Loading branch information
worryg0d committed Nov 26, 2024
1 parent 37030fb commit 63c6f5d
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- SetHost API for Query (CASSGO-4)

### Changed

- Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19)
Expand Down
47 changes: 45 additions & 2 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"io"
"math"
"math/big"
Expand All @@ -45,6 +44,8 @@ import (
"time"
"unicode"

"github.com/stretchr/testify/require"

"gopkg.in/inf.v0"
)

Expand Down Expand Up @@ -3303,7 +3304,6 @@ func TestUnsetColBatch(t *testing.T) {
}
var id, mInt, count int
var mText string

if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
t.Fatalf("Failed to select with err: %v", err)
} else if count != 2 {
Expand Down Expand Up @@ -3338,3 +3338,46 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestQuery_SetHost(t *testing.T) {
// This test ensures that queries are sent to the specified host only

session := createSession(t)
defer session.Close()

hosts, err := session.GetHosts()
if err != nil {
t.Fatal(err)
}

for _, expectedHost := range hosts {
const iterations = 5
for i := 0; i < iterations; i++ {
var actualHostID string
err := session.Query("SELECT host_id FROM system.local").
SetHost(expectedHost).
Scan(&actualHostID)
if err != nil {
t.Fatal(err)
}

if expectedHost.HostID() != actualHostID {
t.Fatalf("Expected query to be executed on host %s, but it was executed on %s",
expectedHost.HostID(),
actualHostID,
)
}
}
}

// ensuring that the driver properly handles the case
// when specified host for the query is down
host := hosts[0]
host.state = NodeDown
err = session.Query("SELECT host_id FROM system.local").
SetHost(host).
Exec()
if !errors.Is(err, ErrNoConnections) {
t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err)
}
}
31 changes: 27 additions & 4 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type ExecutableQuery interface {
Keyspace() string
Table() string
IsIdempotent() bool
GetHost() *HostInfo

withContext(context.Context) ExecutableQuery

Expand Down Expand Up @@ -83,12 +84,27 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
}

func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
var hostIter NextHost

// checking if the host is specified for the query,
// if it is, the query should be executed at the specified host
host := qry.GetHost()
if host != nil {
hostIter = func() SelectedHost {
return (*selectedHost)(host)
}
}

// if host is not specified for the query,
// then a host will be picked by HostSelectionPolicy
if hostIter == nil {
hostIter = q.policy.Pick(qry)
}

// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
if host != nil || !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry, hostIter), nil
}

Expand Down Expand Up @@ -129,12 +145,17 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
selectedHost := hostIter()
rt := qry.retryPolicy()
specifiedHost := qry.GetHost()

var lastErr error
var iter *Iter
for selectedHost != nil {
host := selectedHost.Info()
if host == nil || !host.IsUp() {
if specifiedHost != nil && host != nil && !host.IsUp() {
return &Iter{err: ErrNoConnections}
}

if (host == nil || !host.IsUp()) && specifiedHost == nil {
selectedHost = hostIter()
continue
}
Expand Down Expand Up @@ -166,7 +187,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne

// Exit if the query was successful
// or query is not idempotent or no retry policy defined
if iter.err == nil || !qry.IsIdempotent() || rt == nil {
// Also, if there is specified host for the query to be executed on
// and query execution is failed we should exit
if iter.err == nil || specifiedHost != nil || !qry.IsIdempotent() || rt == nil {
return iter
}

Expand Down
29 changes: 29 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,10 @@ type Query struct {

// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo

// host specifies the host on which the query should be executed.
// If it is nil, then the host is picked by HostSelectionPolicy
host *HostInfo
}

type queryRoutingInfo struct {
Expand Down Expand Up @@ -1430,6 +1434,18 @@ func (q *Query) releaseAfterExecution() {
q.decRefCount()
}

// SetHosts allows to define on which host the query should be executed.
// If host == nil, then the HostSelectionPolicy will be used to pick a host.
func (q *Query) SetHost(host *HostInfo) *Query {
q.host = host
return q
}

// GetHost returns host on which query should be executed.
func (q *Query) GetHost() *HostInfo {
return q.host
}

// Iter represents an iterator that can be used to iterate over all rows that
// were returned by a query. The iterator might send additional queries to the
// database during the iteration if paging was enabled.
Expand Down Expand Up @@ -2045,6 +2061,10 @@ func (b *Batch) releaseAfterExecution() {
// that would race with speculative executions.
}

func (b *Batch) GetHost() *HostInfo {
return nil
}

type BatchType byte

const (
Expand Down Expand Up @@ -2177,6 +2197,15 @@ func (t *traceWriter) Trace(traceId []byte) {
}
}

// GetHosts returns a list of hosts found via queries to system.local and system.peers
func (s *Session) GetHosts() ([]*HostInfo, error) {
hosts, _, err := s.hostSource.GetHosts()
if err != nil {
return nil, err
}
return hosts, nil
}

type ObservedQuery struct {
Keyspace string
Statement string
Expand Down

0 comments on commit 63c6f5d

Please sign in to comment.