diff --git a/go/cmd/vtgateproxy/vtgateproxy.go b/go/cmd/vtgateproxy/vtgateproxy.go index a2763e2a6c7..22d556b5a44 100644 --- a/go/cmd/vtgateproxy/vtgateproxy.go +++ b/go/cmd/vtgateproxy/vtgateproxy.go @@ -17,9 +17,13 @@ limitations under the License. package main import ( + "log" "math/rand" + "net" "time" + "google.golang.org/grpc" + "google.golang.org/grpc/channelz/service" "vitess.io/vitess/go/exit" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vtgateproxy" @@ -38,6 +42,14 @@ func main() { servenv.ParseFlags("vtgateproxy") servenv.Init() + lis, err := net.Listen("tcp", "localhost:8153") + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + s := grpc.NewServer() + service.RegisterChannelzServiceToServer(s) + go s.Serve(lis) + servenv.OnRun(func() { // Flags are parsed now. Parse the template using the actual flag value and overwrite the current template. vtgateproxy.RegisterJsonDiscovery() diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index 055a2a1f677..b282abc9edb 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -1,14 +1,18 @@ package vtgateproxy import ( + "bytes" + "crypto/sha256" "encoding/json" "flag" "fmt" + "io" "math/rand" "os" "strconv" "time" + "google.golang.org/grpc/attributes" "google.golang.org/grpc/resolver" ) @@ -51,6 +55,7 @@ type JSONGateConfigDiscovery struct { } func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + fmt.Printf("Start registration for target: %v\n", target.URL.String()) queryOpts := target.URL.Query() queryParamCount := queryOpts.Get("num_connections") queryAZID := queryOpts.Get("az_id") @@ -84,9 +89,11 @@ func (*JSONGateConfigDiscovery) Scheme() string { return "vtgate" } func RegisterJsonDiscovery() { fmt.Printf("Registering: %v\n", *jsonDiscoveryConfig) - resolver.Register(&JSONGateConfigDiscovery{ + jsonDiscovery := &JSONGateConfigDiscovery{ JsonPath: *jsonDiscoveryConfig, - }) + } + resolver.Register(jsonDiscovery) + fmt.Printf("Registered %v scheme\n", jsonDiscovery.Scheme()) } type resolveFilters struct { @@ -106,56 +113,56 @@ type resolveJSONGateConfig struct { filters resolveFilters } -func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, error) { +type discoverySlackAZ struct{} +type discoverySlackType struct{} + +func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error) { config := []DiscoveryHost{} + fmt.Printf("Loading config %v\n", r.jsonPath) data, err := os.ReadFile(r.jsonPath) if err != nil { - return nil, err + return nil, nil, err } err = json.Unmarshal(data, &config) if err != nil { fmt.Printf("parse err: %v\n", err) - return nil, err + return nil, nil, err } - fmt.Printf("%v\n", config) - addrs := []resolver.Address{} for _, s := range config { - // Apply filters + az := attributes.New(discoverySlackAZ{}, s.AZId).WithValue(discoverySlackType{}, s.Type) + + // Filter hosts to this gate type if r.filters.gate_type != "" { if r.filters.gate_type != s.Type { - // fmt.Printf("Dropped non matching type: %v\n", s.Type) continue } } - if r.filters.az_id != "" { - if r.filters.az_id != s.AZId { - fmt.Printf("Dropped non matching az: %v\n", s.AZId) - continue - } - } // Add matching hosts to registration list - fmt.Printf("selected host for discovery: %v %v\n", fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc), s) - addrs = append(addrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc)}) + addrs = append(addrs, resolver.Address{ + Addr: fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc), + BalancerAttributes: az, + }) } + fmt.Printf("Addrs: %v\n", addrs) + // Shuffle to ensure every host has a different order to iterate through r.rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }) - // Slice off the first N hosts, optionally - if r.num_connections > 0 && r.num_connections <= len(addrs) { - addrs = addrs[0:r.num_connections] + h := sha256.New() + if _, err := io.Copy(h, bytes.NewReader(data)); err != nil { + return nil, nil, err } - fmt.Printf("Returning discovery: %v\n", addrs) - - return &addrs, nil + fmt.Printf("Returning discovery: %d hosts checksum %x\n", len(addrs), h.Sum(nil)) + return &addrs, h.Sum(nil), nil } func (r *resolveJSONGateConfig) start() { @@ -163,7 +170,7 @@ func (r *resolveJSONGateConfig) start() { r.rand = rand.New(rand.NewSource(time.Now().UnixNano())) // Immediately load the initial config - addrs, err := r.loadConfig() + addrs, hash, err := r.loadConfig() if err == nil { // if we parse ok, populate the local address store r.cc.UpdateState(resolver.State{Addresses: *addrs}) @@ -175,33 +182,46 @@ func (r *resolveJSONGateConfig) start() { if err != nil { return } - lastLoaded := time.Now() - go func() { for range r.ticker.C { checkFileStat, err := os.Stat(r.jsonPath) + if err != nil { + fmt.Printf("Error stat'ing config %v\n", err) + continue + } isUnchanged := checkFileStat.Size() == fileStat.Size() || checkFileStat.ModTime() == fileStat.ModTime() - isNotExpired := time.Since(lastLoaded) < 1*time.Minute - if isUnchanged && isNotExpired { + if isUnchanged { // no change continue } - lastLoaded = time.Now() fileStat = checkFileStat fmt.Printf("Detected config change\n") - addrs, err := r.loadConfig() + addrs, newHash, err := r.loadConfig() if err != nil { // better luck next loop // TODO: log this - fmt.Print("oh no\n") + fmt.Print("Can't load config: %v\n", err) + continue + } + + // Make sure this wasn't a spurious change by checking the hash + if bytes.Compare(hash, newHash) == 0 && newHash != nil { + fmt.Printf("No content changed in discovery file... ignoring\n") continue } + hash = newHash + + fmt.Printf("Loaded %d hosts\n", len(*addrs)) + fmt.Printf("Loaded %v", addrs) r.cc.UpdateState(resolver.State{Addresses: *addrs}) } }() + + fmt.Printf("Loaded hosts, starting ticker\n") + } func (r *resolveJSONGateConfig) ResolveNow(o resolver.ResolveNowOptions) {} func (r *resolveJSONGateConfig) Close() { diff --git a/go/vt/vtgateproxy/gate_balancer.go b/go/vt/vtgateproxy/gate_balancer.go new file mode 100644 index 00000000000..77f8de98c19 --- /dev/null +++ b/go/vt/vtgateproxy/gate_balancer.go @@ -0,0 +1,119 @@ +package vtgateproxy + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "sync/atomic" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +// Name is the name of az affinity balancer. +const Name = "slack_affinity_balancer" +const MetadataAZKey = "grpc-slack-az-metadata" +const MetadataHostAffinityCount = "grpc-slack-num-connections-metadata" + +var logger = grpclog.Component("slack_affinity_balancer") + +func WithSlackAZAffinityContext(ctx context.Context, azID string, numConnections string) context.Context { + ctx = metadata.AppendToOutgoingContext(ctx, MetadataAZKey, azID, MetadataHostAffinityCount, numConnections) + return ctx +} + +func newBuilder() balancer.Builder { + return base.NewBalancerBuilder(Name, &slackAZAffinityBalancer{}, base.Config{HealthCheck: true}) +} + +func init() { + balancer.Register(newBuilder()) +} + +type slackAZAffinityBalancer struct{} + +func (*slackAZAffinityBalancer) Build(info base.PickerBuildInfo) balancer.Picker { + logger.Infof("slackAZAffinityBalancer: Build called with info: %v", info) + fmt.Printf("Rebuilding picker\n") + + if len(info.ReadySCs) == 0 { + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + } + allSubConns := []balancer.SubConn{} + subConnsByAZ := map[string][]balancer.SubConn{} + + for sc := range info.ReadySCs { + subConnInfo, _ := info.ReadySCs[sc] + az := subConnInfo.Address.BalancerAttributes.Value(discoverySlackAZ{}).(string) + + allSubConns = append(allSubConns, sc) + subConnsByAZ[az] = append(subConnsByAZ[az], sc) + } + return &slackAZAffinityPicker{ + allSubConns: allSubConns, + subConnsByAZ: subConnsByAZ, + } +} + +type slackAZAffinityPicker struct { + // allSubConns is all subconns that were in the ready state when the picker was created + allSubConns []balancer.SubConn + subConnsByAZ map[string][]balancer.SubConn + nextByAZ sync.Map + next uint32 +} + +// Pick the next in the list from the list of subconns (RR) +func (p *slackAZAffinityPicker) pickFromSubconns(scList []balancer.SubConn, nextIndex uint32) (balancer.PickResult, error) { + subConnsLen := uint32(len(scList)) + + if subConnsLen == 0 { + return balancer.PickResult{}, errors.New("No hosts in list") + } + + fmt.Printf("Select offset: %v %v %v\n", nextIndex, nextIndex%subConnsLen, len(scList)) + + sc := scList[nextIndex%subConnsLen] + return balancer.PickResult{SubConn: sc}, nil +} + +func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + hdrs, _ := metadata.FromOutgoingContext(info.Ctx) + numConnections := 0 + keys := hdrs.Get(MetadataAZKey) + if len(keys) < 1 { + return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) + } + az := keys[0] + + if az == "" { + return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) + } + + keys = hdrs.Get(MetadataHostAffinityCount) + if len(keys) > 0 { + if i, err := strconv.Atoi(keys[0]); err != nil { + numConnections = i + } + } + + subConns := p.subConnsByAZ[az] + if len(subConns) == 0 { + fmt.Printf("No subconns in az and gate type, pick from anywhere\n") + return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) + } + val, _ := p.nextByAZ.LoadOrStore(az, new(uint32)) + ptr := val.(*uint32) + atomic.AddUint32(ptr, 1) + + if len(subConns) >= numConnections && numConnections > 0 { + fmt.Printf("Limiting to first %v\n", numConnections) + return p.pickFromSubconns(subConns[0:numConnections], *ptr) + } else { + return p.pickFromSubconns(subConns, *ptr) + } +} diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index be5d91d430b..08b19f8c256 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -21,7 +21,9 @@ package vtgateproxy import ( "context" "flag" + "fmt" "io" + "net/url" "strings" "sync" "time" @@ -51,17 +53,31 @@ var ( ) type VTGateProxy struct { - targetConns map[string]*vtgateconn.VTGateConn - mu sync.Mutex + targetConns map[string]*vtgateconn.VTGateConn + mu sync.Mutex + azID string + gateType string + numConnections string } func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vtgateconn.VTGateConn, error) { + targetURL, err := url.Parse(target) + if err != nil { + return nil, err + } + + proxy.azID = targetURL.Query().Get("az_id") + proxy.numConnections = targetURL.Query().Get("num_connections") + proxy.gateType = targetURL.Host + + fmt.Printf("Getting connection for %v in %v with %v connections\n", target, proxy.azID, proxy.numConnections) + // If the connection exists, return it proxy.mu.Lock() - conn, _ := proxy.targetConns[target] - if conn != nil { + existingConn, _ := proxy.targetConns[target] + if existingConn != nil { proxy.mu.Unlock() - return conn, nil + return existingConn, nil } proxy.mu.Unlock() @@ -70,12 +86,11 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt // grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { // return append(opts, grpc.WithBlock()), nil // }) - grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { - return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`)), nil + return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"slack_affinity_balancer":{}}]}`)), nil }) - conn, err := vtgateconn.DialProtocol(ctx, "grpc", target) + conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.numConnections), "grpc", target) if err != nil { return nil, err } @@ -105,7 +120,7 @@ func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.Execu // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. func (proxy *VTGateProxy) CloseSession(ctx context.Context, session *vtgateconn.VTGateSession) error { - return session.CloseSession(ctx) + return session.CloseSession(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType)) } // ResolveTransaction resolves the specified 2PC transaction. @@ -119,7 +134,6 @@ func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgateconn.VTGat } func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (qr *sqltypes.Result, err error) { - // Intercept "use" statements since they just have to update the local session if strings.HasPrefix(sql, "use ") { targetString := sqlescape.UnescapeID(sql[4:]) @@ -127,11 +141,19 @@ func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGat return &sqltypes.Result{}, nil } - return session.Execute(ctx, sql, bindVariables) + t := time.Now() + qr, err = session.Execute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) + logSql := sql + if len(logSql) > 40 { + logSql = logSql[:40] + } + + fmt.Printf("Execute %s [%s]\n", logSql, time.Since(t)) + return qr, err } func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { - stream, err := session.StreamExecute(ctx, sql, bindVariables) + stream, err := session.StreamExecute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) if err != nil { return err }