Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Sep 20, 2024
1 parent 12d951f commit b8051cf
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 54 deletions.
61 changes: 8 additions & 53 deletions connection_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"log/slog"
"math/rand"
"sort"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -450,20 +449,20 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool {
}

// getTendConns returns all the gRPC client connections for tend operations.
func (cp *connectionProvider) getTendConns() []grpcClientConn {
func (cp *connectionProvider) getTendConns() []*connection {
cp.nodeConnsLock.RLock()
defer cp.nodeConnsLock.RUnlock()

conns := make([]grpcClientConn, len(cp.seedConns)+len(cp.nodeConns))
conns := make([]*connection, len(cp.seedConns)+len(cp.nodeConns))
i := 0

for _, conn := range cp.seedConns {
conns[i] = conn.grpcConn
conns[i] = conn
i++
}

for _, conn := range cp.nodeConns {
conns[i] = conn.conn.grpcConn
conns[i] = conn.conn
i++
}

Expand All @@ -480,13 +479,12 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6
for _, conn := range conns {
wg.Add(1)

go func(conn grpcClientConn) {
go func(conn *connection) {
defer wg.Done()

logger := cp.logger.With(slog.String("host", conn.Target()))
client := protos.NewClusterInfoServiceClient(conn)
logger := cp.logger.With(slog.String("host", conn.grpcConn.Target()))

clusterID, err := client.GetClusterId(ctx, &emptypb.Empty{})
clusterID, err := conn.clusterInfoClient.GetClusterId(ctx, &emptypb.Empty{})
if err != nil {
logger.WarnContext(ctx, "failed to get cluster ID", slog.Any("error", err))
}
Expand All @@ -503,7 +501,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6

logger.DebugContext(ctx, "new cluster ID found", slog.Uint64("clusterID", clusterID.GetId()))

endpointsResp, err := client.GetClusterEndpoints(ctx, endpointsReq)
endpointsResp, err := conn.clusterInfoClient.GetClusterEndpoints(ctx, endpointsReq)
if err != nil {
logger.ErrorContext(ctx, "failed to get cluster endpoints", slog.Any("error", err))
return
Expand Down Expand Up @@ -644,49 +642,6 @@ func (cp *connectionProvider) tend(ctx context.Context) {
}
}

func endpointEqual(a, b *protos.ServerEndpoint) bool {
return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls
}

func endpointListEqual(a, b *protos.ServerEndpointList) bool {
if len(a.Endpoints) != len(b.Endpoints) {
return false
}

aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints))
copy(aEndpoints, a.Endpoints)

bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints))
copy(bEndpoints, b.Endpoints)

sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool {
return func(i, j int) bool {
if endpoints[i].Address < endpoints[j].Address {
return true
} else if endpoints[i].Address > endpoints[j].Address {
return false
}

return endpoints[i].Port < endpoints[j].Port
}
}

sort.Slice(aEndpoints, sortFunc(aEndpoints))
sort.Slice(bEndpoints, sortFunc(bEndpoints))

for i, endpoint := range aEndpoints {
if !endpointEqual(endpoint, bEndpoints[i]) {
return false
}
}

return true
}

func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort {
return NewHostPort(endpoint.Address, int(endpoint.Port))
}

// createGrpcConnFromEndpoints creates a gRPC client connection from the first
// successful endpoint in endpoints.
func (cp *connectionProvider) createGrpcConnFromEndpoints(
Expand Down
221 changes: 221 additions & 0 deletions connection_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -168,3 +169,223 @@ func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) {

assert.Equal(t, errors.New("no seed connections found"), err)
}
func TestUpdateClusterConns_NoNewClusterID(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx := context.Background()

cp := &connectionProvider{
logger: slog.Default(),
nodeConns: make(map[uint64]*connectionAndEndpoints),
seedConns: []*connection{},
tlsConfig: nil,
seeds: HostPortSlice{},
nodeConnsLock: &sync.RWMutex{},
tendInterval: time.Second * 1,
clusterID: 123,
listenerName: nil,
isLoadBalancer: false,
token: nil,
stopTendChan: make(chan struct{}),
closed: atomic.Bool{},
}

cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NoNewClusterID"))

cp.logger.Debug("Setting up existing node connections")

grpcConn1 := NewMockgrpcClientConn(ctrl)
mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl)
grpcConn2 := NewMockgrpcClientConn(ctrl)
mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl)

grpcConn1.
EXPECT().
Target().
Return("")

mockClusterInfoClient1.
EXPECT().
GetClusterId(gomock.Any(), gomock.Any()).
Return(&protos.ClusterId{
Id: 123,
}, nil)

grpcConn2.
EXPECT().
Target().
Return("")

mockClusterInfoClient2.
EXPECT().
GetClusterId(gomock.Any(), gomock.Any()).
Return(&protos.ClusterId{
Id: 123,
}, nil)

// Existing node connections
cp.nodeConns[1] = &connectionAndEndpoints{
conn: &connection{
grpcConn: grpcConn1,
clusterInfoClient: mockClusterInfoClient1,
},
endpoints: &protos.ServerEndpointList{},
}

cp.nodeConns[2] = &connectionAndEndpoints{
conn: &connection{
grpcConn: grpcConn2,
clusterInfoClient: mockClusterInfoClient2,
},
endpoints: &protos.ServerEndpointList{},
}

cp.logger.Debug("Running updateClusterConns")

cp.updateClusterConns(ctx)

assert.Equal(t, uint64(123), cp.clusterID)
assert.Len(t, cp.nodeConns, 2)
}

// func TestUpdateClusterConns_NewClusterID(t *testing.T) {
// ctrl := gomock.NewController(t)
// defer ctrl.Finish()

// ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
// defer cancel()

// cp := &connectionProvider{
// logger: slog.Default(),
// nodeConns: make(map[uint64]*connectionAndEndpoints),
// seedConns: []*connection{},
// tlsConfig: &tls.Config{},
// seeds: HostPortSlice{},
// nodeConnsLock: &sync.RWMutex{},
// tendInterval: time.Second * 1,
// clusterID: 123,
// listenerName: nil,
// isLoadBalancer: false,
// token: nil,
// stopTendChan: make(chan struct{}),
// closed: atomic.Bool{},
// }

// cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID"))

// cp.logger.Debug("Setting up existing node connections")

// grpcConn1 := NewMockgrpcClientConn(ctrl)
// mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl)
// grpcConn2 := NewMockgrpcClientConn(ctrl)
// mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl)

// grpcConn1.
// EXPECT().
// Target().
// Return("")

// mockClusterInfoClient1.
// EXPECT().
// GetClusterId(gomock.Any(), gomock.Any()).
// Return(&protos.ClusterId{
// Id: 123,
// }, nil)

// grpcConn1.
// EXPECT().
// Close().
// Return(nil)

// grpcConn2.
// EXPECT().
// Target().
// Return("")

// mockClusterInfoClient2.
// EXPECT().
// GetClusterId(gomock.Any(), gomock.Any()).
// Return(&protos.ClusterId{
// Id: 456,
// }, nil)

// mockClusterInfoClient2.
// EXPECT().
// GetClusterEndpoints(gomock.Any(), gomock.Any()).
// Return(&protos.ClusterNodeEndpoints{
// Endpoints: map[uint64]*protos.ServerEndpointList{
// 3: {
// Endpoints: []*protos.ServerEndpoint{
// {
// Address: "1.1.1.1",
// Port: 3000,
// },
// },
// },
// 4: {
// Endpoints: []*protos.ServerEndpoint{
// {
// Address: "2.2.2.2",
// Port: 3000,
// },
// },
// },
// },
// }, nil)

// grpcConn2.
// EXPECT().
// Close().
// Return(nil)

// // Existing node connections
// cp.nodeConns[1] = &connectionAndEndpoints{
// conn: &connection{
// grpcConn: grpcConn1,
// clusterInfoClient: mockClusterInfoClient1,
// },
// endpoints: &protos.ServerEndpointList{},
// }

// cp.nodeConns[2] = &connectionAndEndpoints{
// conn: &connection{
// grpcConn: grpcConn2,
// clusterInfoClient: mockClusterInfoClient2,
// },
// endpoints: &protos.ServerEndpointList{},
// }

// cp.logger.Debug("Running updateClusterConns")

// // New cluster ID
// // newEndpoints := &protos.ServerEndpointList{
// // Endpoints: []*protos.ServerEndpoint{
// // {
// // Address: "host1",
// // Port: 3000,
// // },
// // {
// // Address: "host2",
// // Port: 3000,
// // },
// // },
// // }

// cp.updateClusterConns(ctx)

// // cp.checkAndSetNodeConns(ctx, map[uint64]*protos.ServerEndpointList{
// // 1: newEndpoints,
// // 2: newEndpoints,
// // })

// // cp.removeDownNodes(map[uint64]*protos.ServerEndpointList{
// // 1: newEndpoints,
// // 2: newEndpoints,
// // })

// // cp.updateClusterConns(ctx)

// assert.Equal(t, uint64(456), cp.clusterID)
// assert.Len(t, cp.nodeConns, 2)
// }
44 changes: 44 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package avs

import (
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -74,6 +75,49 @@ func createIndexStatusRequest(namespace, name string) *protos.IndexStatusRequest
}
}

func endpointEqual(a, b *protos.ServerEndpoint) bool {
return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls
}

func endpointListEqual(a, b *protos.ServerEndpointList) bool {
if len(a.Endpoints) != len(b.Endpoints) {
return false
}

aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints))
copy(aEndpoints, a.Endpoints)

bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints))
copy(bEndpoints, b.Endpoints)

sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool {
return func(i, j int) bool {
if endpoints[i].Address < endpoints[j].Address {
return true
} else if endpoints[i].Address > endpoints[j].Address {
return false
}

return endpoints[i].Port < endpoints[j].Port
}
}

sort.Slice(aEndpoints, sortFunc(aEndpoints))
sort.Slice(bEndpoints, sortFunc(bEndpoints))

for i, endpoint := range aEndpoints {
if !endpointEqual(endpoint, bEndpoints[i]) {
return false
}
}

return true
}

func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort {
return NewHostPort(endpoint.Address, int(endpoint.Port))
}

var minimumFullySupportedAVSVersion = newVersion("0.10.0")

type version []any
Expand Down
Loading

0 comments on commit b8051cf

Please sign in to comment.