Skip to content

Commit

Permalink
Merge pull request #206 from sylwiaszunejko/dc_name_validation
Browse files Browse the repository at this point in the history
Add the datacenter name validation if provided
  • Loading branch information
sylwiaszunejko authored Jul 5, 2024
2 parents 0bd6283 + ec64325 commit ed9f13a
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 4 deletions.
48 changes: 48 additions & 0 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ type HostSelectionPolicy interface {
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
// IsOperational checks if host policy can properly work with given Session/Cluster/ClusterConfig
IsOperational(*Session) error
}

// SelectedHost is an interface returned when picking a host from a host
Expand Down Expand Up @@ -363,6 +365,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
func (r *roundRobinHostPolicy) Init(*Session) {}
func (r *roundRobinHostPolicy) Reset() {}
func (r *roundRobinHostPolicy) IsOperational(*Session) error { return nil }

// Experimental, this interface and use may change
func (r *roundRobinHostPolicy) SetTablets(tablets []*TabletInfo) {}
Expand Down Expand Up @@ -489,6 +492,10 @@ func (t *tokenAwareHostPolicy) Reset() {
t.logger = nil
}

func (t *tokenAwareHostPolicy) IsOperational(session *Session) error {
return t.fallback.IsOperational(session)
}

func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
return t.fallback.IsLocal(host)
}
Expand Down Expand Up @@ -823,6 +830,7 @@ type hostPoolHostPolicy struct {

func (r *hostPoolHostPolicy) Init(*Session) {}
func (r *hostPoolHostPolicy) Reset() {}
func (r *hostPoolHostPolicy) IsOperational(*Session) error { return nil }
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string) {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true }
Expand Down Expand Up @@ -984,6 +992,27 @@ func (d *dcAwareRR) Reset() {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string) {}

func (d *dcAwareRR) IsOperational(session *Session) error {
if session.cfg.disableInit || session.cfg.disableControlConn {
return nil
}

hosts, _, err := session.hostSource.GetHosts()
if err != nil {
return fmt.Errorf("gocql: unable to check if session is operational: %v", err)
}
for _, host := range hosts {
if !session.cfg.filterHost(host) && host.DataCenter() == d.local {
// Policy can work properly only if there is at least one host from target DC
// No need to check host status, since it could be down due to the outage
// We only need to make sure that policy is not misconfigured with wrong DC
return nil
}
}

return fmt.Errorf("gocql: datacenter %s in the policy was not found in the topology - probable DC aware policy misconfiguration", d.local)
}

func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
return host.DataCenter() == d.local
}
Expand Down Expand Up @@ -1088,6 +1117,25 @@ func (d *rackAwareRR) Reset() {}
func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *rackAwareRR) SetPartitioner(p string) {}

func (d *rackAwareRR) IsOperational(session *Session) error {
if session.cfg.disableInit || session.cfg.disableControlConn {
return nil
}
hosts, _, err := session.hostSource.GetHosts()
if err != nil {
return fmt.Errorf("gocql: unable to check if session is operational: %v", err)
}
for _, host := range hosts {
if !session.cfg.filterHost(host) && host.DataCenter() == d.localDC && host.Rack() == d.localRack {
// Policy can work properly only if there is at least one host from target DC+Rack
// No need to check host status, since it could be down due to the outage
// We only need to make sure that policy is not misconfigured with wrong DC+Rack
return nil
}
}
return fmt.Errorf("gocql: rack %s/%s was not found in the topology - probable Rack aware policy misconfiguration", d.localDC, d.localRack)
}

func (d *rackAwareRR) MaxHostTier() uint {
return 2
}
Expand Down
41 changes: 41 additions & 0 deletions policies_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build integration && scylla
// +build integration,scylla

package gocql

import (
"testing"
)

// Check if session fail to start if DC name provided in the policy is wrong
func TestDCValidationTokenAware(t *testing.T) {
cluster := createCluster()

fallback := DCAwareRoundRobinPolicy("WRONG_DC")
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationDCAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = DCAwareRoundRobinPolicy("WRONG_DC")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}

func TestDCValidationRackAware(t *testing.T) {
cluster := createCluster()
cluster.PoolConfig.HostSelectionPolicy = RackAwareRoundRobinPolicy("WRONG_DC", "RACK")

_, err := cluster.CreateSession()
if err == nil {
t.Fatal("createSession was expected to fail with wrong DC name provided.")
}
}
17 changes: 13 additions & 4 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {

}

func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover)
func TestHostPolicy_DCAwareRR_disableDCFailover(t *testing.T) {
p := DCAwareRoundRobinPolicy("local", HostPolicyOptionDisableDCFailover)

hosts := [...]*HostInfo{
{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"},
Expand All @@ -616,19 +616,28 @@ func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) {
}

got := make(map[string]bool, len(hosts))
var dcs []string

it := p.Pick(nil)
for h := it(); h != nil; h = it() {
id := h.Info().hostId
dc := h.Info().dataCenter

if got[id] {
t.Fatalf("got duplicate host %s", id)
}
got[id] = true
dcs = append(dcs, dc)
}

if len(got) != 0 {
t.Fatalf("expected %d hosts got %d", 0, len(got))
if len(got) != 2 {
t.Fatalf("expected %d hosts got %d", 2, len(got))
}

for _, dc := range dcs {
if dc == "remote" {
t.Fatalf("got remote dc but failover was diabled")
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
}
}

if s.policy.IsOperational(s) != nil {
return nil, fmt.Errorf("gocql: unable to create session: %v", err)
}

return s, nil
}

Expand Down

0 comments on commit ed9f13a

Please sign in to comment.