Skip to content

Commit

Permalink
PWX-33631: Hold round-robin lock only when needed (#2335)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Griffiths <[email protected]>
  • Loading branch information
Grant Griffiths authored Sep 13, 2023
1 parent aa0e7d0 commit f769b71
Showing 1 changed file with 79 additions and 39 deletions.
118 changes: 79 additions & 39 deletions pkg/loadbalancer/roundrobin.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type roundRobin struct {
cluster cluster.Cluster
connMap map[string]*TimedSDKConn
nextCreateNodeNumber int
mu sync.Mutex
mu sync.RWMutex
grpcServerPort string
}

Expand Down Expand Up @@ -65,9 +65,6 @@ func NewRoundRobinBalancer(
}

func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) {
rr.mu.Lock()
defer rr.mu.Unlock()

// Get all nodes and sort them
cluster, err := rr.cluster.Enumerate()
if err != nil {
Expand All @@ -80,11 +77,41 @@ func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.Client
return cluster.Nodes[i].Id < cluster.Nodes[j].Id
})

// Clean up connections for missing nodes
rr.cleanupMissingNodeConnections(ctx, cluster.Nodes)

// Get target node info and set next round-robin node.
// nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero
targetNodeEndpoint, isRemoteConn := rr.getTargetAndIncrement(&cluster)

// Get conn for this node, otherwise create new conn
timedSDKConn, ok := rr.getNodeConnection(targetNodeEndpoint)
if !ok {
var err error
rrlogger.WithContext(ctx).Infof("Round-robin connecting to node %s:%s", targetNodeEndpoint, rr.grpcServerPort)
remoteConn, err := grpcserver.ConnectWithTimeout(
fmt.Sprintf("%s:%s", targetNodeEndpoint, rr.grpcServerPort),
[]grpc.DialOption{
grpc.WithInsecure(),
grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor),
}, 10*time.Second)
if err != nil {
return nil, isRemoteConn, err
}
timedSDKConn = &TimedSDKConn{
Conn: remoteConn,
}

rr.setNodeConnection(targetNodeEndpoint, timedSDKConn)
}

// Keep track of when this conn was last accessed
rrlogger.WithContext(ctx).Infof("Using remote connection to SDK node %s:%s", targetNodeEndpoint, rr.grpcServerPort)
timedSDKConn.LastUsage = time.Now()
return timedSDKConn.Conn, isRemoteConn, nil

}

func (rr *roundRobin) getTargetAndIncrement(cluster *api.Cluster) (string, bool) {
rr.mu.Lock()
defer rr.mu.Unlock()
var (
targetNodeNumber int
isRemoteConn bool
Expand All @@ -101,36 +128,38 @@ func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.Client
targetNodeEndpoint := targetNode.MgmtIp
rr.nextCreateNodeNumber = (targetNodeNumber + 1) % len(cluster.Nodes)

// Get conn for this node, otherwise create new conn
return targetNodeEndpoint, isRemoteConn
}

func (rr *roundRobin) getNodeConnection(targetNodeEndpoint string) (*TimedSDKConn, bool) {
if len(rr.connMap) == 0 {
rr.mu.Lock()
rr.connMap = make(map[string]*TimedSDKConn)
rr.mu.Unlock()
}
if rr.connMap[targetNodeEndpoint] == nil {
var err error
rrlogger.WithContext(ctx).Infof("Round-robin connecting to node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, rr.grpcServerPort)
remoteConn, err := grpcserver.ConnectWithTimeout(
fmt.Sprintf("%s:%s", targetNodeEndpoint, rr.grpcServerPort),
[]grpc.DialOption{
grpc.WithInsecure(),
grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor),
}, 10*time.Second)
if err != nil {
return nil, isRemoteConn, err
}

rr.connMap[targetNodeEndpoint] = &TimedSDKConn{
Conn: remoteConn,
}
}
rr.mu.RLock()
timedSDKConn, ok := rr.connMap[targetNodeEndpoint]
rr.mu.RUnlock()

// Keep track of when this conn was last accessed
rrlogger.WithContext(ctx).Infof("Using remote connection to SDK node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, rr.grpcServerPort)
rr.connMap[targetNodeEndpoint].LastUsage = time.Now()
return rr.connMap[targetNodeEndpoint].Conn, isRemoteConn, nil
return timedSDKConn, ok
}

// setNodeConnection stores tsc as the cached connection for
// targetNodeEndpoint. The write lock on rr.mu is held for the whole update.
//
// The connection map is allocated first when it is nil or empty. An entry
// already stored under the same endpoint is overwritten; it is not closed
// here.
func (rr *roundRobin) setNodeConnection(targetNodeEndpoint string, tsc *TimedSDKConn) {
	rr.mu.Lock()
	defer rr.mu.Unlock()

	conns := rr.connMap
	if len(conns) == 0 {
		// Covers both a nil map (writing to it would panic) and an empty one.
		conns = make(map[string]*TimedSDKConn)
		rr.connMap = conns
	}
	conns[targetNodeEndpoint] = tsc
}

func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) {
func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) int {
rr.mu.Lock()
defer rr.mu.Unlock()

numConnsClosed := 0
nodesMap := make(map[string]bool)
for _, node := range nodes {
nodesMap[node.MgmtIp] = true
Expand All @@ -142,19 +171,18 @@ func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes [
rrlogger.WithContext(ctx).Errorf("failed to close conn to %s: %v", ip, err)
}
delete(rr.connMap, ip)
numConnsClosed++
}
}

return numConnsClosed
}

func (rr *roundRobin) cleanupConnections() {
ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentRoundRobinBalancer)
func (rr *roundRobin) cleanupExpiredConnections() int {
rr.mu.Lock()
defer rr.mu.Unlock()

rrlogger.Tracef("Cleaning up open gRPC connections created for round-robin balancing.")

// Clean all expired connections
numConnsClosed := 0

for ip, timedConn := range rr.connMap {
expiryTime := timedConn.LastUsage.Add(connIdleConnLength)

Expand All @@ -170,6 +198,19 @@ func (rr *roundRobin) cleanupConnections() {
}
}

return numConnsClosed
}

func (rr *roundRobin) cleanupConnections() {
ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentRoundRobinBalancer)
rrlogger.Tracef("Cleaning up open gRPC connections created for round-robin balancing.")

// Clean all expired connections
expiredConnsClosed := rr.cleanupExpiredConnections()
if expiredConnsClosed > 0 {
rrlogger.Infof("Cleaned up %v expired node connections created for round-robin balancing. %v connections remaining", expiredConnsClosed, len(rr.connMap))
}

// Get all nodes and cleanup conns for missing/decommissioned nodes
nodesResp, err := rr.cluster.Enumerate()
if err != nil {
Expand All @@ -180,10 +221,9 @@ func (rr *roundRobin) cleanupConnections() {
rrlogger.Errorf("no nodes available to cleanup: %v", err)
return
}
rr.cleanupMissingNodeConnections(ctx, nodesResp.Nodes)

if numConnsClosed > 0 {
rrlogger.Infof("Cleaned up %v connections created for round-robin balancing. %v connections remaining", numConnsClosed, len(rr.connMap))
missingNodeConnsClosed := rr.cleanupMissingNodeConnections(ctx, nodesResp.Nodes)
if missingNodeConnsClosed > 0 {
rrlogger.Infof("Cleaned up %v connections for missing nodes created for round-robin balancing. %v connections remaining", missingNodeConnsClosed, len(rr.connMap))
}
}

Expand Down

0 comments on commit f769b71

Please sign in to comment.