From 16f3ba1a8104d09960ebdefaecfb4a78b98237f5 Mon Sep 17 00:00:00 2001 From: Aditya Dani Date: Tue, 19 Dec 2023 09:08:50 -0800 Subject: [PATCH] PWX-35430: Honor cluster domains in round-robin balancer. (#2390) - When the cluster is running with cluster-domains enabled, the round robin balancer should choose the nodes to forward the request within the same cluster domain. Signed-off-by: Aditya Dani --- api/api.go | 3 + pkg/loadbalancer/balancer.go | 8 +++ pkg/loadbalancer/mock/balancer.go | 16 +++++ pkg/loadbalancer/roundrobin.go | 46 +++++++++++--- pkg/loadbalancer/roundrobin_test.go | 95 +++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 10 deletions(-) create mode 100644 pkg/loadbalancer/roundrobin_test.go diff --git a/api/api.go b/api/api.go index 6531867ee..202bff52e 100644 --- a/api/api.go +++ b/api/api.go @@ -284,6 +284,8 @@ type Node struct { SchedulerTopology *SchedulerTopology // Flag indicating whether the node is a quorum member or not NonQuorumMember bool + // DomainID is the ID of the cluster domain to which this node belongs. + DomainID string } // FluentDConfig describes ip and port of a fluentdhost. @@ -1321,6 +1323,7 @@ func (v *VolumeSpec) IsPureVolume() bool { func (v *VolumeSpec) IsPureBlockVolume() bool { return v.GetProxySpec() != nil && v.GetProxySpec().IsPureBlockBackend() } + // GetCloneCreatorOwnership returns the appropriate ownership for the // new snapshot and if an update is required func (v *VolumeSpec) GetCloneCreatorOwnership(ctx context.Context) (*Ownership, bool) { diff --git a/pkg/loadbalancer/balancer.go b/pkg/loadbalancer/balancer.go index 8e2b2f464..a7b557a59 100644 --- a/pkg/loadbalancer/balancer.go +++ b/pkg/loadbalancer/balancer.go @@ -16,6 +16,10 @@ type Balancer interface { // The boolean return argument is set to false if the connection is created // to the local node. 
GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) + // GetRemoteNode returns the management IP of the node to which the next remote + // connection will be created. The boolean return argument is set to false + // if the connection is created to the local node. + GetRemoteNode() (string, bool, error) } type nullBalancer struct{} @@ -28,3 +32,7 @@ func NewNullBalancer() Balancer { func (n *nullBalancer) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) { return nil, false, fmt.Errorf("remote connections not supported") } + +func (n *nullBalancer) GetRemoteNode() (string, bool, error) { + return "", false, fmt.Errorf("remote connections not supported") +} diff --git a/pkg/loadbalancer/mock/balancer.go b/pkg/loadbalancer/mock/balancer.go index 31b795f22..72b842e3d 100644 --- a/pkg/loadbalancer/mock/balancer.go +++ b/pkg/loadbalancer/mock/balancer.go @@ -35,6 +35,22 @@ func (m *MockBalancer) EXPECT() *MockBalancerMockRecorder { return m.recorder } +// GetRemoteNode mocks base method. +func (m *MockBalancer) GetRemoteNode() (string, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRemoteNode") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetRemoteNode indicates an expected call of GetRemoteNode. +func (mr *MockBalancerMockRecorder) GetRemoteNode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemoteNode", reflect.TypeOf((*MockBalancer)(nil).GetRemoteNode)) +} + // GetRemoteNodeConnection mocks base method. 
func (m *MockBalancer) GetRemoteNodeConnection(arg0 context.Context) (*grpc.ClientConn, bool, error) { m.ctrl.T.Helper() diff --git a/pkg/loadbalancer/roundrobin.go b/pkg/loadbalancer/roundrobin.go index 03885db60..70c255765 100644 --- a/pkg/loadbalancer/roundrobin.go +++ b/pkg/loadbalancer/roundrobin.go @@ -64,23 +64,49 @@ func NewRoundRobinBalancer( return rr, nil } -func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) { +func (rr *roundRobin) GetRemoteNode() (string, bool, error) { // Get all nodes and sort them cluster, err := rr.cluster.Enumerate() if err != nil { - return nil, false, err + return "", false, err } if len(cluster.Nodes) < 1 { - return nil, false, errors.New("cluster nodes for remote connection not found") + return "", false, errors.New("cluster nodes for remote connection not found") + } + // Get our node object + selfNode, err := rr.cluster.Inspect(cluster.NodeId) + if err != nil { + return "", false, err + } + var filteredNodes []*api.Node + + if selfNode.DomainID != "" { + // Filter out nodes from a different cluster domain. + for _, node := range cluster.Nodes { + if selfNode.DomainID == node.DomainID { + filteredNodes = append(filteredNodes, node) + } + } + } else { + filteredNodes = cluster.Nodes } - sort.Slice(cluster.Nodes, func(i, j int) bool { - return cluster.Nodes[i].Id < cluster.Nodes[j].Id + + sort.Slice(filteredNodes, func(i, j int) bool { + return filteredNodes[i].Id < filteredNodes[j].Id }) // Get target node info and set next round robbin node. 
// nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero - targetNodeEndpoint, isRemoteConn := rr.getTargetAndIncrement(&cluster) + targetNodeEndpoint, isRemoteConn := rr.getTargetAndIncrement(filteredNodes, selfNode.Id) + + return targetNodeEndpoint, isRemoteConn, nil +} +func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) { + targetNodeEndpoint, isRemoteConn, err := rr.GetRemoteNode() + if err != nil { + return nil, false, err + } // Get conn for this node, otherwise create new conn timedSDKConn, ok := rr.getNodeConnection(targetNodeEndpoint) if !ok { @@ -109,7 +135,7 @@ func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.Client } -func (rr *roundRobin) getTargetAndIncrement(cluster *api.Cluster) (string, bool) { +func (rr *roundRobin) getTargetAndIncrement(filteredNodes []*api.Node, selfNodeID string) (string, bool) { rr.mu.Lock() defer rr.mu.Unlock() var ( @@ -119,14 +145,14 @@ func (rr *roundRobin) getTargetAndIncrement(cluster *api.Cluster) (string, bool) if rr.nextCreateNodeNumber != 0 { targetNodeNumber = rr.nextCreateNodeNumber } - targetNode := cluster.Nodes[targetNodeNumber] - if targetNode.Id != cluster.NodeId { + targetNode := filteredNodes[targetNodeNumber] + if targetNode.Id != selfNodeID { // NodeID set on the cluster object is this node's ID. // Target NodeID does not match with our NodeID, so this will be a remote connection. 
isRemoteConn = true } targetNodeEndpoint := targetNode.MgmtIp - rr.nextCreateNodeNumber = (targetNodeNumber + 1) % len(cluster.Nodes) + rr.nextCreateNodeNumber = (targetNodeNumber + 1) % len(filteredNodes) return targetNodeEndpoint, isRemoteConn } diff --git a/pkg/loadbalancer/roundrobin_test.go b/pkg/loadbalancer/roundrobin_test.go new file mode 100644 index 000000000..8199bf9a6 --- /dev/null +++ b/pkg/loadbalancer/roundrobin_test.go @@ -0,0 +1,95 @@ +package loadbalancer + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/libopenstorage/openstorage/api" + "github.com/libopenstorage/openstorage/cluster/mock" + "github.com/libopenstorage/openstorage/pkg/sched" + "github.com/stretchr/testify/require" +) + +var ( + ips = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4", "127.0.0.5", "127.0.0.6"} + ids = []string{"1", "2", "3", "4", "5", "6"} +) + +func getMockClusterResponse(enableClusterDomain bool) *api.Cluster { + nodes := make([]*api.Node, 0) + for i := 0; i < len(ips); i++ { + node := &api.Node{ + MgmtIp: ips[i], + Id: ids[i], + } + if enableClusterDomain { + if i%2 == 0 { + node.DomainID = "domain1" + } else { + node.DomainID = "domain2" + } + } + nodes = append(nodes, node) + } + return &api.Cluster{ + NodeId: ids[0], // self node + Nodes: nodes, + } +} + +func TestGetRemoteNodeWithoutDomains(t *testing.T) { + if sched.Instance() == nil { + sched.Init(time.Second) + } + ctrl := gomock.NewController(t) + cc := mock.NewMockCluster(ctrl) + rr, err := NewRoundRobinBalancer(cc, "1234") + require.NoError(t, err, "failed to create round robin balancer") + + cc.EXPECT().Enumerate().Return(*getMockClusterResponse(false), nil).AnyTimes() + cc.EXPECT().Inspect(ids[0]).Return(api.Node{MgmtIp: ips[0], Id: ids[0]}, nil).AnyTimes() + + for loop := 0; loop < 2; loop++ { + targetNode, isRemoteConn, err := rr.GetRemoteNode() + require.NoError(t, err, "failed to get remote node") + require.Equal(t, targetNode, ips[0], "target node is not 
as expected") + require.False(t, isRemoteConn, "isRemoteConn is not as expected") + + for i := 1; i < len(ips); i++ { + targetNode, isRemoteConn, err := rr.GetRemoteNode() + require.NoError(t, err, "failed to get remote node") + require.Equal(t, targetNode, ips[i], "target node is not as expected") + require.True(t, isRemoteConn, "isRemoteConn is not as expected") + + } + } +} + +func TestGetRemoteNodeWithDomains(t *testing.T) { + if sched.Instance() == nil { + sched.Init(time.Second) + } + ctrl := gomock.NewController(t) + cc := mock.NewMockCluster(ctrl) + rr, err := NewRoundRobinBalancer(cc, "1234") + require.NoError(t, err, "failed to create round robin balancer") + + cc.EXPECT().Enumerate().Return(*getMockClusterResponse(true), nil).AnyTimes() + cc.EXPECT().Inspect(ids[0]).Return(api.Node{MgmtIp: ips[0], Id: ids[0], DomainID: "domain1"}, nil).AnyTimes() + + for loop := 0; loop < 2; loop++ { + targetNode, isRemoteConn, err := rr.GetRemoteNode() + require.NoError(t, err, "failed to get remote node") + require.Equal(t, targetNode, ips[0], "target node is not as expected") + require.False(t, isRemoteConn, "isRemoteConn is not as expected") + + for i := 1; i < 3; i++ { + targetNode, isRemoteConn, err := rr.GetRemoteNode() + require.NoError(t, err, "failed to get remote node") + require.Equal(t, targetNode, ips[i*2], "target node is not as expected") + require.True(t, isRemoteConn, "isRemoteConn is not as expected") + + } + } +}