diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index b0497628b49..258d8d56cb8 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -24,10 +24,8 @@ import ( "io" "math/rand" "os" - "strings" "time" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/resolver" "vitess.io/vitess/go/vt/log" @@ -35,7 +33,9 @@ import ( var ( jsonDiscoveryConfig = flag.String("json_config", "", "json file describing the host list to use fot vitess://vtgate resolution") - numConnectionsInt = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain") + addressField = flag.String("address_field", "address", "field name in the json file containing the address") + portField = flag.String("port_field", "port", "field name in the json file containing the port") + numConnections = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain") ) // File based discovery for vtgate grpc endpoints @@ -75,24 +75,30 @@ type JSONGateConfigDiscovery struct { const queryParamFilterPrefix = "filter_" func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { - log.V(100).Infof("Start registration for target: %v\n", target.URL.String()) - queryOpts := target.URL.Query() - gateType := target.URL.Host - - filters := hostFilters{} - filters["type"] = gateType - for k := range queryOpts { - if strings.HasPrefix(k, queryParamFilterPrefix) { - filteredPrefix := strings.TrimPrefix(k, queryParamFilterPrefix) - filters[filteredPrefix] = queryOpts.Get(k) + attrs := target.URL.Query() + + poolType := "" + if *poolTypeAttr != "" { + poolType = attrs.Get(*poolTypeAttr) + if poolType == "" { + return nil, fmt.Errorf("pool type attribute %s not in target", *poolTypeAttr) } } + // affinity is optional + affinity := "" + if *affinityAttr != "" { + affinity = attrs.Get(*affinityAttr) + } + + log.V(100).Infof("Start discovery for target %v poolType %s affinity %s\n", target.URL.String(), poolType, affinity) + r := &JSONGateConfigResolver{ target: target, cc: cc, jsonPath: b.JsonPath, - filters: filters, + poolType: poolType, + affinity: affinity, } r.start() return r, nil @@ -115,85 +121,89 @@ type JSONGateConfigResolver struct { target resolver.Target cc resolver.ClientConn jsonPath string - ticker *time.Ticker - rand *rand.Rand // safe for concurrent use. - filters hostFilters + poolType string + affinity string + + ticker *time.Ticker + rand *rand.Rand // safe for concurrent use. } -type matchesFilter struct{} +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func jsonDump(data interface{}) string { + json, _ := json.Marshal(data) + return string(json) +} func (r *JSONGateConfigResolver) resolve() (*[]resolver.Address, []byte, error) { - pairs := []map[string]interface{}{} - log.V(100).Infof("resolving target %s to %d connections\n", r.target.URL.String(), *numConnectionsInt) + log.V(100).Infof("resolving target %s to %d connections\n", r.target.URL.String(), *numConnections) data, err := os.ReadFile(r.jsonPath) if err != nil { return nil, nil, err } - err = json.Unmarshal(data, &pairs) + hosts := []map[string]interface{}{} + err = json.Unmarshal(data, &hosts) if err != nil { log.Errorf("error parsing JSON discovery file %s: %v\n", r.jsonPath, err) return nil, nil, err } - allAddrs := []resolver.Address{} - filteredAddrs := []resolver.Address{} - var addrs []resolver.Address - for _, pair := range pairs { - matchesAll := true - for k, v := range r.filters { - if pair[k] != v { - matchesAll = false + // optionally filter to only hosts that match the pool type + if r.poolType != "" { + candidates := []map[string]interface{}{} + for _, host := range hosts { + hostType, ok := host[*poolTypeAttr] + if ok && hostType == r.poolType { + candidates = append(candidates, host) + log.V(1000).Infof("matched host %s with type %s", jsonDump(host), hostType) + } else { + log.V(1000).Infof("skipping host %s with type %s", jsonDump(host), hostType) } } + hosts = candidates + } - if matchesAll { - filteredAddrs = append(filteredAddrs, resolver.Address{ - Addr: fmt.Sprintf("%s:%s", pair["nebula_address"], pair["grpc"]), - BalancerAttributes: attributes.New(matchesFilter{}, "match"), - }) - } + // Shuffle to ensure every host has a different order to iterate through + r.rand.Shuffle(len(hosts), func(i, j int) { + hosts[i], hosts[j] = hosts[j], hosts[i] + }) - // Must filter by type - t, ok := r.filters["type"] - if ok { - if pair["type"] == t { - // Add matching hosts to registration list - allAddrs = append(allAddrs, resolver.Address{ - Addr: fmt.Sprintf("%s:%s", pair["nebula_address"], pair["grpc"]), - BalancerAttributes: attributes.New(matchesFilter{}, "nomatch"), - }) + // If affinity is specified, then shuffle those hosts to the front + if r.affinity != "" { + i := 0 + for j := 0; j < len(hosts); j++ { + hostAffinity, ok := hosts[j][*affinityAttr] + if ok && hostAffinity == r.affinity { + hosts[i], hosts[j] = hosts[j], hosts[i] + i++ } } } - // Nothing in the filtered list? Get them all - if len(filteredAddrs) == 0 { - addrs = allAddrs - } else if *numConnectionsInt == 0 { - addrs = allAddrs - } else if len(filteredAddrs) > *numConnectionsInt { - addrs = filteredAddrs[0:*numConnectionsInt] - } else if len(allAddrs) > *numConnectionsInt { - addrs = allAddrs[0:*numConnectionsInt] - } else { - addrs = allAddrs + // Grab the first N addresses, and voila! + var addrs []resolver.Address + hosts = hosts[:min(*numConnections, len(hosts))] + for _, host := range hosts { + addrs = append(addrs, resolver.Address{ + Addr: fmt.Sprintf("%s:%s", host[*addressField], host[*portField]), + }) } - // Shuffle to ensure every host has a different order to iterate through - r.rand.Shuffle(len(addrs), func(i, j int) { - addrs[i], addrs[j] = addrs[j], addrs[i] - }) - h := sha256.New() if _, err := io.Copy(h, bytes.NewReader(data)); err != nil { return nil, nil, err } sum := h.Sum(nil) - log.V(100).Infof("resolved %s to addrs: 0x%x, %v\n", r.target.URL.String(), sum, addrs) + log.V(100).Infof("resolved %s to hosts %s addrs: 0x%x, %v\n", r.target.URL.String(), jsonDump(hosts), sum, addrs) return &addrs, sum, nil } diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index bc142cb0d5d..c91000a90df 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -25,26 +25,22 @@ import ( "net/url" "strings" "sync" - "time" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/vterrors" _ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" "vitess.io/vitess/go/vt/vtgate/vtgateconn" ) var ( - dialTimeout = flag.Duration("dial_timeout", 5*time.Second, "dialer timeout for the GRPC connection") - defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") - sysVarSetEnabled = flag.Bool("enable_system_settings", true, "This will enable the system settings to be changed per session at the database connection level") + poolTypeAttr = flag.String("pool_type_attr", "", "Attribute (both mysql connection and JSON file) used to specify the target vtgate type and filter the hosts, e.g. 'type'") + affinityAttr = flag.String("affinity_attr", "", "Attribute (both mysql protocol connection and JSON file) used to specify the routing affinity , e.g. 'az_id'") vtGateProxy *VTGateProxy = &VTGateProxy{ targetConns: map[string]*vtgateconn.VTGateConn{}, @@ -94,23 +90,22 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt } func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) { - target, ok := connectionAttributes["target"] - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "no target string supplied by client") + + if *poolTypeAttr != "" { + _, ok := connectionAttributes[*poolTypeAttr] + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "pool type attribute %s not supplied by client", *poolTypeAttr) + } } targetUrl := url.URL{ Scheme: "vtgate", - Host: target, + Host: "pool", } - filters := metadata.Pairs() values := url.Values{} for k, v := range connectionAttributes { - if strings.HasPrefix(k, queryParamFilterPrefix) { - filters.Append(k, v) - values.Set(k, v) - } + values.Set(k, v) } targetUrl.RawQuery = values.Encode()