diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index b282abc9edb..532823e8330 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -9,7 +9,7 @@ import ( "io" "math/rand" "os" - "strconv" + "strings" "time" "google.golang.org/grpc/attributes" @@ -18,6 +18,7 @@ import ( var ( jsonDiscoveryConfig = flag.String("json_config", "", "json file describing the host list to use fot vitess://vtgate resolution") + numConnectionsInt = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain") ) // File based discovery for vtgate grpc endpoints @@ -54,33 +55,27 @@ type JSONGateConfigDiscovery struct { JsonPath string } +const queryParamFilterPrefix = "filter_" + func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { fmt.Printf("Start registration for target: %v\n", target.URL.String()) queryOpts := target.URL.Query() - queryParamCount := queryOpts.Get("num_connections") - queryAZID := queryOpts.Get("az_id") - num_connections := 0 - gateType := target.URL.Host - if queryParamCount != "" { - num_connections, _ = strconv.Atoi(queryParamCount) - } - - filters := resolveFilters{ - gate_type: gateType, - } - - if queryAZID != "" { - filters.az_id = queryAZID + filters := hostFilters{} + filters["type"] = gateType + for k, _ := range queryOpts { + if strings.HasPrefix(k, queryParamFilterPrefix) { + filteredPrefix := strings.TrimPrefix(k, queryParamFilterPrefix) + filters[filteredPrefix] = queryOpts.Get(k) + } } r := &resolveJSONGateConfig{ - target: target, - cc: cc, - jsonPath: b.JsonPath, - num_connections: num_connections, - filters: filters, + target: target, + cc: cc, + jsonPath: b.JsonPath, + filters: filters, } r.start() return r, nil @@ -101,23 +96,25 @@ type resolveFilters struct { az_id string } +type hostFilters = map[string]string + // exampleResolver is a // Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver). type resolveJSONGateConfig struct { - target resolver.Target - cc resolver.ClientConn - jsonPath string - ticker *time.Ticker - rand *rand.Rand // safe for concurrent use. - num_connections int - filters resolveFilters + target resolver.Target + cc resolver.ClientConn + jsonPath string + ticker *time.Ticker + rand *rand.Rand // safe for concurrent use. + filters hostFilters } type discoverySlackAZ struct{} type discoverySlackType struct{} +type matchesFilter struct{} func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error) { - config := []DiscoveryHost{} + pairs := []map[string]string{} fmt.Printf("Loading config %v\n", r.jsonPath) data, err := os.ReadFile(r.jsonPath) @@ -125,27 +122,28 @@ func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error return nil, nil, err } - err = json.Unmarshal(data, &config) + err = json.Unmarshal(data, &pairs) if err != nil { fmt.Printf("parse err: %v\n", err) return nil, nil, err } addrs := []resolver.Address{} - for _, s := range config { - az := attributes.New(discoverySlackAZ{}, s.AZId).WithValue(discoverySlackType{}, s.Type) + for _, pair := range pairs { + attributes := attributes.New(matchesFilter{}, true) - // Filter hosts to this gate type - if r.filters.gate_type != "" { - if r.filters.gate_type != s.Type { + for k, v := range r.filters { + if pair[k] != v { + fmt.Printf("Filtering out %v", pair) + attributes.WithValue(matchesFilter{}, false) continue } } // Add matching hosts to registration list addrs = append(addrs, resolver.Address{ - Addr: fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc), - BalancerAttributes: az, + Addr: fmt.Sprintf("%s:%s", pair["nebula_address"], pair["grpc"]), + BalancerAttributes: attributes, }) } diff --git a/go/vt/vtgateproxy/gate_balancer.go b/go/vt/vtgateproxy/gate_balancer.go index 77f8de98c19..5045622f407 100644 --- a/go/vt/vtgateproxy/gate_balancer.go +++ b/go/vt/vtgateproxy/gate_balancer.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strconv" + "strings" "sync" "sync/atomic" @@ -16,13 +17,14 @@ import ( // Name is the name of az affinity balancer. const Name = "slack_affinity_balancer" -const MetadataAZKey = "grpc-slack-az-metadata" const MetadataHostAffinityCount = "grpc-slack-num-connections-metadata" +const MetadataDiscoveryFilterPrefix = "grpc_discovery_filter_" var logger = grpclog.Component("slack_affinity_balancer") -func WithSlackAZAffinityContext(ctx context.Context, azID string, numConnections string) context.Context { - ctx = metadata.AppendToOutgoingContext(ctx, MetadataAZKey, azID, MetadataHostAffinityCount, numConnections) +func WithSlackAZAffinityContext(ctx context.Context, numConnections string, filters metadata.MD) context.Context { + metadata.NewOutgoingContext(ctx, filters) + ctx = metadata.AppendToOutgoingContext(ctx, MetadataHostAffinityCount, numConnections) return ctx } @@ -44,27 +46,30 @@ func (*slackAZAffinityBalancer) Build(info base.PickerBuildInfo) balancer.Picker return base.NewErrPicker(balancer.ErrNoSubConnAvailable) } allSubConns := []balancer.SubConn{} - subConnsByAZ := map[string][]balancer.SubConn{} + subConnsByFiltered := []balancer.SubConn{} for sc := range info.ReadySCs { subConnInfo, _ := info.ReadySCs[sc] - az := subConnInfo.Address.BalancerAttributes.Value(discoverySlackAZ{}).(string) + matchesFilter := subConnInfo.Address.BalancerAttributes.Value(matchesFilter{}).(bool) allSubConns = append(allSubConns, sc) - subConnsByAZ[az] = append(subConnsByAZ[az], sc) + if matchesFilter { + subConnsByFiltered = append(subConnsByFiltered, sc) + } + } return &slackAZAffinityPicker{ - allSubConns: allSubConns, - subConnsByAZ: subConnsByAZ, + allSubConns: allSubConns, + filteredSubConns: subConnsByFiltered, } } type slackAZAffinityPicker struct { // allSubConns is all subconns that were in the ready state when the picker was created - allSubConns []balancer.SubConn - subConnsByAZ map[string][]balancer.SubConn - nextByAZ sync.Map - next uint32 + allSubConns []balancer.SubConn + filteredSubConns []balancer.SubConn + nextByAZ sync.Map + next uint32 } // Pick the next in the list from the list of subconns (RR) @@ -90,6 +95,18 @@ func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResul } az := keys[0] + filteredSubconns := p.allSubConns + for k, v := range hdrs { + if strings.HasPrefix(k, MetadataDiscoveryFilterPrefix) { + filterName := strings.TrimPrefix(k, MetadataDiscoveryFilterPrefix) + filterValue := v + } + } + + for _, s := range v { + + } + if az == "" { return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 68869dbd323..0ac2885adaa 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -23,12 +23,13 @@ import ( "flag" "fmt" "io" - "net/url" + "strconv" "strings" "sync" "time" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/grpcclient" @@ -41,8 +42,7 @@ import ( ) var ( - dialTimeout = flag.Duration("dial_timeout", 5*time.Second, "dialer timeout for the GRPC connection") - + dialTimeout = flag.Duration("dial_timeout", 5*time.Second, "dialer timeout for the GRPC connection") defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") sysVarSetEnabled = flag.Bool("enable_system_settings", true, "This will enable the system settings to be changed per session at the database connection level") @@ -53,24 +53,13 @@ var ( ) type VTGateProxy struct { - targetConns map[string]*vtgateconn.VTGateConn - mu sync.Mutex - azID string - gateType string - numConnections string + targetConns map[string]*vtgateconn.VTGateConn + mu sync.Mutex } -func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vtgateconn.VTGateConn, error) { - targetURL, err := url.Parse(target) - if err != nil { - return nil, err - } - - proxy.azID = targetURL.Query().Get("az_id") - proxy.numConnections = targetURL.Query().Get("num_connections") - proxy.gateType = targetURL.Host - - fmt.Printf("Getting connection for %v in %v with %v connections\n", target, proxy.azID, proxy.numConnections) +func (proxy *VTGateProxy) getConnection(ctx context.Context, target string, filters metadata.MD) (*vtgateconn.VTGateConn, error) { + numConnectionsString := strconv.Itoa(*numConnectionsInt) + fmt.Printf("Getting connection for %v in %v with %v filters\n", target, filters) // If the connection exists, return it proxy.mu.Lock() @@ -90,7 +79,7 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"slack_affinity_balancer":{}}]}`)), nil }) - conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.numConnections), "grpc", target) + conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, numConnectionsString, filters), "grpc", target) if err != nil { return nil, err } @@ -108,7 +97,14 @@ func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.Execu return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "no target string supplied by client") } - conn, err := proxy.getConnection(ctx, target) + filters := metadata.Pairs() + for k, v := range connectionAttributes { + if strings.HasPrefix(k, MetadataDiscoveryFilterPrefix) { + filters.Append(k, v) + } + } + + conn, err := proxy.getConnection(ctx, target, filters) if err != nil { return nil, err }