diff --git a/pkg/loadbalancer/roundrobin.go b/pkg/loadbalancer/roundrobin.go index 77c3848ca..7c606a7a2 100644 --- a/pkg/loadbalancer/roundrobin.go +++ b/pkg/loadbalancer/roundrobin.go @@ -27,7 +27,7 @@ type roundRobin struct { cluster cluster.Cluster connMap map[string]*TimedSDKConn nextCreateNodeNumber int - mu sync.Mutex + mu sync.RWMutex grpcServerPort string } @@ -65,9 +65,6 @@ func NewRoundRobinBalancer( } func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) { - rr.mu.Lock() - defer rr.mu.Unlock() - // Get all nodes and sort them cluster, err := rr.cluster.Enumerate() if err != nil { @@ -85,6 +82,39 @@ func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.Client // Get target node info and set next round robbin node. // nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero + targetNodeEndpoint, isRemoteConn := rr.getTargetAndIncrement(&cluster) + + // Get conn for this node, otherwise create new conn + timedSDKConn, ok := rr.getNodeConnection(targetNodeEndpoint) + if !ok { + var err error + rrlogger.WithContext(ctx).Infof("Round-robin connecting to node %s:%s", targetNodeEndpoint, rr.grpcServerPort) + remoteConn, err := grpcserver.ConnectWithTimeout( + fmt.Sprintf("%s:%s", targetNodeEndpoint, rr.grpcServerPort), + []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor), + }, 10*time.Second) + if err != nil { + return nil, isRemoteConn, err + } + timedSDKConn = &TimedSDKConn{ + Conn: remoteConn, + } + + rr.setNodeConnection(targetNodeEndpoint, timedSDKConn) + } + + // Keep track of when this conn was last accessed + rrlogger.WithContext(ctx).Infof("Using remote connection to SDK node %s:%s", targetNodeEndpoint, rr.grpcServerPort) + timedSDKConn.LastUsage = time.Now() + return timedSDKConn.Conn, isRemoteConn, nil + +} + +func (rr *roundRobin) getTargetAndIncrement(cluster *api.Cluster) (string, bool) { + rr.mu.Lock() + defer rr.mu.Unlock() var ( targetNodeNumber int isRemoteConn bool @@ -101,36 +131,36 @@ func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.Client targetNodeEndpoint := targetNode.MgmtIp rr.nextCreateNodeNumber = (targetNodeNumber + 1) % len(cluster.Nodes) - // Get conn for this node, otherwise create new conn + return targetNodeEndpoint, isRemoteConn +} + +func (rr *roundRobin) getNodeConnection(targetNodeEndpoint string) (*TimedSDKConn, bool) { if len(rr.connMap) == 0 { rr.connMap = make(map[string]*TimedSDKConn) } - if rr.connMap[targetNodeEndpoint] == nil { - var err error - rrlogger.WithContext(ctx).Infof("Round-robin connecting to node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, rr.grpcServerPort) - remoteConn, err := grpcserver.ConnectWithTimeout( - fmt.Sprintf("%s:%s", targetNodeEndpoint, rr.grpcServerPort), - []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor), - }, 10*time.Second) - if err != nil { - return nil, isRemoteConn, err - } - rr.connMap[targetNodeEndpoint] = &TimedSDKConn{ - Conn: remoteConn, - } - } + rr.mu.RLock() + timedSDKConn, ok := rr.connMap[targetNodeEndpoint] + rr.mu.RUnlock() - // Keep track of when this conn was last accessed - rrlogger.WithContext(ctx).Infof("Using remote connection to SDK node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, rr.grpcServerPort) - rr.connMap[targetNodeEndpoint].LastUsage = time.Now() - return rr.connMap[targetNodeEndpoint].Conn, isRemoteConn, nil + return timedSDKConn, ok +} + +func (rr *roundRobin) setNodeConnection(targetNodeEndpoint string, tsc *TimedSDKConn) { + if len(rr.connMap) == 0 { + rr.connMap = make(map[string]*TimedSDKConn) + } + rr.mu.Lock() + rr.connMap[targetNodeEndpoint] = tsc + rr.mu.Unlock() } -func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) { +func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) int { + rr.mu.Lock() + defer rr.mu.Unlock() + + numConnsClosed := 0 nodesMap := make(map[string]bool) for _, node := range nodes { nodesMap[node.MgmtIp] = true @@ -142,19 +172,18 @@ func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes [ rrlogger.WithContext(ctx).Errorf("failed to close conn to %s: %v", ip, err) } delete(rr.connMap, ip) + numConnsClosed++ } } + + return numConnsClosed } -func (rr *roundRobin) cleanupConnections() { - ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentRoundRobinBalancer) +func (rr *roundRobin) cleanupExpiredConnections() int { rr.mu.Lock() defer rr.mu.Unlock() - - rrlogger.Tracef("Cleaning up open gRPC connections created for round-robin balancing.") - - // Clean all expired connections numConnsClosed := 0 + for ip, timedConn := range rr.connMap { expiryTime := timedConn.LastUsage.Add(connIdleConnLength) @@ -170,6 +199,19 @@ func (rr *roundRobin) cleanupConnections() { } } + return numConnsClosed +} + +func (rr *roundRobin) cleanupConnections() { + ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentRoundRobinBalancer) + rrlogger.Tracef("Cleaning up open gRPC connections created for round-robin balancing.") + + // Clean all expired connections + expiredConnsClosed := rr.cleanupExpiredConnections() + if expiredConnsClosed > 0 { + rrlogger.Infof("Cleaned up %v expired node connections created for round-robin balancing. %v connections remaining", expiredConnsClosed, len(rr.connMap)) + } + // Get all nodes and cleanup conns for missing/decommissioned nodes nodesResp, err := rr.cluster.Enumerate() if err != nil { @@ -180,10 +222,9 @@ func (rr *roundRobin) cleanupConnections() { rrlogger.Errorf("no nodes available to cleanup: %v", err) return } - rr.cleanupMissingNodeConnections(ctx, nodesResp.Nodes) - - if numConnsClosed > 0 { - rrlogger.Infof("Cleaned up %v connections created for round-robin balancing. %v connections remaining", numConnsClosed, len(rr.connMap)) + missingNodeConnsClosed := rr.cleanupMissingNodeConnections(ctx, nodesResp.Nodes) + if missingNodeConnsClosed > 0 { + rrlogger.Infof("Cleaned up %v connections for missing nodes created for round-robin balancing. %v connections remaining", missingNodeConnsClosed, len(rr.connMap)) } }