diff --git a/go/vt/vtgateproxy/gate_balancer.go b/go/vt/vtgateproxy/gate_balancer.go index 523e4c4c44b..494df3b0578 100644 --- a/go/vt/vtgateproxy/gate_balancer.go +++ b/go/vt/vtgateproxy/gate_balancer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "sync" "sync/atomic" @@ -16,12 +17,12 @@ import ( // Name is the name of az affinity balancer. const Name = "slack_affinity_balancer" const MetadataAZKey = "grpc-slack-az-metadata" -const MetadataGateTypeKey = "grpc-slack-gate-type-metadata" +const MetadataHostAffinityCount = "grpc-slack-num-connections-metadata" var logger = grpclog.Component("slack_affinity_balancer") -func WithSlackAZAffinityContext(ctx context.Context, azID string, gateType string) context.Context { - ctx = metadata.AppendToOutgoingContext(ctx, MetadataAZKey, azID, MetadataGateTypeKey, gateType) +func WithSlackAZAffinityContext(ctx context.Context, azID string, numConnections string) context.Context { + ctx = metadata.AppendToOutgoingContext(ctx, MetadataAZKey, azID, MetadataHostAffinityCount, numConnections) return ctx } @@ -82,21 +83,24 @@ func (p *slackAZAffinityPicker) pickFromSubconns(scList []balancer.SubConn, next func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { hdrs, _ := metadata.FromOutgoingContext(info.Ctx) - fmt.Printf("Headers: %v %v\n", hdrs, info) + numConnections := 0 keys := hdrs.Get(MetadataAZKey) if len(keys) < 1 { - fmt.Printf("uh oh - missing keys: %v %v %v\n", keys, hdrs, info.Ctx) - fmt.Printf("no header - pick from anywhere\n") return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) } az := keys[0] if az == "" { - fmt.Printf("Header unset, pick from anywhere\n") return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) } - fmt.Printf("Selecting from az: %v\n", az) + keys = hdrs.Get(MetadataHostAffinityCount) + if len(keys) > 0 { + if i, err := strconv.Atoi(keys[0]); err != nil { + numConnections = i + } + } + subConns := p.subConnsByAZ[az] if len(subConns) == 0 { fmt.Printf("No subconns in az and gate type, pick from anywhere\n") @@ -106,9 +110,9 @@ func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResul ptr := val.(*uint32) atomic.AddUint32(ptr, 1) - if len(subConns) >= 2 { - fmt.Printf("Limiting to first 2\n") - return p.pickFromSubconns(subConns[0:2], *ptr) + if len(subConns) >= numConnections { + fmt.Printf("Limiting to first %v\n", numConnections) + return p.pickFromSubconns(subConns[0:numConnections], *ptr) } else { return p.pickFromSubconns(subConns, *ptr) } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 2c565d67022..997e3d419f6 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -52,10 +52,11 @@ var ( ) type VTGateProxy struct { - targetConns map[string]*vtgateconn.VTGateConn - mu sync.Mutex - azID string - gateType string + targetConns map[string]*vtgateconn.VTGateConn + mu sync.Mutex + azID string + gateType string + numConnections string } func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vtgateconn.VTGateConn, error) { @@ -65,6 +66,7 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt } proxy.azID = targetURL.Query().Get("az_id") + proxy.numConnections = targetURL.Query().Get("num_connections") proxy.gateType = targetURL.Host fmt.Printf("Getting connection for %v in %v\n", target, proxy.azID) @@ -87,7 +89,7 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"slack_affinity_balancer":{}}]}`)), nil }) - conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), "grpc", target) + conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.numConnections), "grpc", target) if err != nil { return nil, err }