diff --git a/go/test/endtoend/vtorc/general/vtorc_test.go b/go/test/endtoend/vtorc/general/vtorc_test.go
index d79e2964f3e..38bc5f34df9 100644
--- a/go/test/endtoend/vtorc/general/vtorc_test.go
+++ b/go/test/endtoend/vtorc/general/vtorc_test.go
@@ -495,3 +495,76 @@ func TestDurabilityPolicySetLater(t *testing.T) {
 	assert.NotNil(t, primary, "should have elected a primary")
 	utils.CheckReplication(t, newCluster, primary, shard0.Vttablets, 10*time.Second)
 }
+
+func TestFullStatusConnectionPooling(t *testing.T) {
+	defer utils.PrintVTOrcLogsOnFailure(t, clusterInfo.ClusterInstance)
+	defer cluster.PanicHandler(t)
+	utils.SetupVttabletsAndVTOrcs(t, clusterInfo, 4, 0, []string{
+		"--tablet_manager_grpc_concurrency=1",
+	}, cluster.VTOrcConfiguration{
+		PreventCrossDataCenterPrimaryFailover: true,
+	}, 1, "")
+	keyspace := &clusterInfo.ClusterInstance.Keyspaces[0]
+	shard0 := &keyspace.Shards[0]
+	vtorc := clusterInfo.ClusterInstance.VTOrcProcesses[0]
+
+	// find primary from topo
+	curPrimary := utils.ShardPrimaryTablet(t, clusterInfo, keyspace, shard0)
+	assert.NotNil(t, curPrimary, "should have elected a primary")
+	vtOrcProcess := clusterInfo.ClusterInstance.VTOrcProcesses[0]
+	utils.WaitForSuccessfulRecoveryCount(t, vtOrcProcess, logic.ElectNewPrimaryRecoveryName, 1)
+	utils.WaitForSuccessfulPRSCount(t, vtOrcProcess, keyspace.Name, shard0.Name, 1)
+
+	// Kill the current primary.
+	_ = curPrimary.VttabletProcess.Kill()
+
+	// Wait until VTOrc notices some problems.
+	status, resp := utils.MakeAPICallRetry(t, vtorc, "/api/replication-analysis", func(_ int, response string) bool {
+		return response == "null"
+	})
+	assert.Equal(t, 200, status)
+	assert.Contains(t, resp, "UnreachablePrimary")
+
+	time.Sleep(1 * time.Minute)
+
+	// Change the primary's ports and restart it.
+	curPrimary.VttabletProcess.Port = clusterInfo.ClusterInstance.GetAndReservePort()
+	curPrimary.VttabletProcess.GrpcPort = clusterInfo.ClusterInstance.GetAndReservePort()
+	err := curPrimary.VttabletProcess.Setup()
+	require.NoError(t, err)
+
+	// See that VTOrc eventually reports no errors.
+	// Wait until there are no problems and the API endpoint returns null.
+	status, resp = utils.MakeAPICallRetry(t, vtorc, "/api/replication-analysis", func(_ int, response string) bool {
+		return response != "null"
+	})
+	assert.Equal(t, 200, status)
+	assert.Equal(t, "null", resp)
+
+	// Repeat the same scenario a second time.
+	// Kill the current primary.
+	_ = curPrimary.VttabletProcess.Kill()
+
+	// Wait until VTOrc notices some problems.
+	status, resp = utils.MakeAPICallRetry(t, vtorc, "/api/replication-analysis", func(_ int, response string) bool {
+		return response == "null"
+	})
+	assert.Equal(t, 200, status)
+	assert.Contains(t, resp, "UnreachablePrimary")
+
+	time.Sleep(1 * time.Minute)
+
+	// Change the primary's ports back to the original ones and restart it.
+	curPrimary.VttabletProcess.Port = curPrimary.HTTPPort
+	curPrimary.VttabletProcess.GrpcPort = curPrimary.GrpcPort
+	err = curPrimary.VttabletProcess.Setup()
+	require.NoError(t, err)
+
+	// See that VTOrc eventually reports no errors.
+	// Wait until there are no problems and the API endpoint returns null.
+	status, resp = utils.MakeAPICallRetry(t, vtorc, "/api/replication-analysis", func(_ int, response string) bool {
+		return response != "null"
+	})
+	assert.Equal(t, 200, status)
+	assert.Equal(t, "null", resp)
+}
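A note for readers of the test above: the callback passed to utils.MakeAPICallRetry is a keep-retrying predicate, not a success check; it returns true for as long as the helper should poll the endpoint again. Pulled out as standalone closures (the names here are illustrative, not part of the patch), the two predicates the test alternates between look like this:

    // Keep polling while VTOrc reports no replication analysis at all.
    waitForProblems := func(_ int, response string) bool {
        return response == "null"
    }

    // Keep polling while VTOrc still reports any problem.
    waitForAllClear := func(_ int, response string) bool {
        return response != "null"
    }

The test also passes --tablet_manager_grpc_concurrency=1 so that VTOrc caches very few gRPC connections per tablet; restarting the tablet on fresh ports then makes the cached connection stale, which appears to be exactly the condition the connection invalidation in this patch is meant to recover from.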
diff --git a/go/test/endtoend/vtorc/utils/utils.go b/go/test/endtoend/vtorc/utils/utils.go
index dca2c7b1e26..00f75740338 100644
--- a/go/test/endtoend/vtorc/utils/utils.go
+++ b/go/test/endtoend/vtorc/utils/utils.go
@@ -733,7 +733,7 @@ func MakeAPICall(t *testing.T, vtorc *cluster.VTOrcProcess, url string) (status
 // The function provided takes in the status and response and returns if we should continue to retry or not
 func MakeAPICallRetry(t *testing.T, vtorc *cluster.VTOrcProcess, url string, retry func(int, string) bool) (status int, response string) {
 	t.Helper()
-	timeout := time.After(10 * time.Second)
+	timeout := time.After(30 * time.Second)
 	for {
 		select {
 		case <-timeout:
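The hunk above only widens the retry window from 10 to 30 seconds. For context, here is a minimal sketch of the loop shape this helper implements, assuming MakeAPICall returns the HTTP status code and response body; the failure message and the one-second poll interval are illustrative, not taken from the source:

    timeout := time.After(30 * time.Second)
    for {
        select {
        case <-timeout:
            t.Fatal("timed out waiting for the API call to succeed") // illustrative message
            return 0, ""
        default:
            status, response := MakeAPICall(t, vtorc, url)
            if !retry(status, response) {
                // The predicate says stop retrying: return the last result.
                return status, response
            }
            time.Sleep(time.Second) // illustrative poll interval
        }
    }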
diff --git a/go/vt/vttablet/grpctmclient/client.go b/go/vt/vttablet/grpctmclient/client.go
index d8ae032bd74..dac6c7e0822 100644
--- a/go/vt/vttablet/grpctmclient/client.go
+++ b/go/vt/vttablet/grpctmclient/client.go
@@ -45,6 +45,15 @@ import (
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
 )
 
+type DialPoolGroup int
+
+const (
+	dialPoolGroupThrottler DialPoolGroup = iota
+	dialPoolGroupVTOrc
+)
+
+type invalidatorFunc func()
+
 var (
 	concurrency = 8
 	cert        string
@@ -92,14 +101,17 @@ type tmc struct {
 	client tabletmanagerservicepb.TabletManagerClient
 }
 
+type addrTmcMap map[string]*tmc
+
 // grpcClient implements both dialer and poolDialer.
 type grpcClient struct {
 	// This cache of connections is to maximize QPS for ExecuteFetchAs{Dba,App},
 	// CheckThrottler and FullStatus. Note we'll keep the clients open and close them upon Close() only.
 	// But that's OK because usually the tasks that use them are one-purpose only.
 	// The map is protected by the mutex.
-	mu           sync.Mutex
-	rpcClientMap map[string]chan *tmc
+	mu             sync.Mutex
+	rpcClientMap   map[string]chan *tmc
+	rpcDialPoolMap map[DialPoolGroup]addrTmcMap
 }
 
 type dialer interface {
@@ -109,6 +121,7 @@ type dialer interface {
 
 type poolDialer interface {
 	dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error)
+	dialDedicatedPool(ctx context.Context, dialPoolGroup DialPoolGroup, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, invalidatorFunc, error)
 }
 
 // Client implements tmclient.TabletManagerClient.
@@ -152,6 +165,17 @@ func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (
 	return tabletmanagerservicepb.NewTabletManagerClient(cc), cc, nil
 }
 
+func (client *grpcClient) createTmc(addr string, opt grpc.DialOption) (*tmc, error) {
+	cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt)
+	if err != nil {
+		return nil, err
+	}
+	return &tmc{
+		cc:     cc,
+		client: tabletmanagerservicepb.NewTabletManagerClient(cc),
+	}, nil
+}
+
 func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) {
 	addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
 	opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name)
@@ -170,14 +194,11 @@ func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Table
 		client.mu.Unlock()
 
 		for i := 0; i < cap(c); i++ {
-			cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt)
+			tm, err := client.createTmc(addr, opt)
 			if err != nil {
 				return nil, err
 			}
-			c <- &tmc{
-				cc:     cc,
-				client: tabletmanagerservicepb.NewTabletManagerClient(cc),
-			}
+			c <- tm
 		}
 	} else {
 		client.mu.Unlock()
@@ -188,6 +209,38 @@ func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Table
 	return result.client, nil
 }
 
+func (client *grpcClient) dialDedicatedPool(ctx context.Context, dialPoolGroup DialPoolGroup, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, invalidatorFunc, error) {
+	addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
+	opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	client.mu.Lock()
+	defer client.mu.Unlock()
+	if client.rpcDialPoolMap == nil {
+		client.rpcDialPoolMap = make(map[DialPoolGroup]addrTmcMap)
+	}
+	if _, ok := client.rpcDialPoolMap[dialPoolGroup]; !ok {
+		client.rpcDialPoolMap[dialPoolGroup] = make(addrTmcMap)
+	}
+	m := client.rpcDialPoolMap[dialPoolGroup]
+	if _, ok := m[addr]; !ok {
+		tm, err := client.createTmc(addr, opt)
+		if err != nil {
+			return nil, nil, err
+		}
+		m[addr] = tm
+	}
+	invalidator := func() {
+		client.mu.Lock()
+		defer client.mu.Unlock()
+		m[addr].cc.Close()
+		delete(m, addr)
+	}
+	return m[addr].client, invalidator, nil
+}
+
 // Close is part of the tmclient.TabletManagerClient interface.
 func (client *grpcClient) Close() {
 	client.mu.Lock()
@@ -611,9 +664,10 @@ func (client *Client) ReplicationStatus(ctx context.Context, tablet *topodatapb.
 // and dialing the other tablet every time is not practical.
 func (client *Client) FullStatus(ctx context.Context, tablet *topodatapb.Tablet) (*replicationdatapb.FullStatus, error) {
 	var c tabletmanagerservicepb.TabletManagerClient
+	var invalidator invalidatorFunc
 	var err error
 	if poolDialer, ok := client.dialer.(poolDialer); ok {
-		c, err = poolDialer.dialPool(ctx, tablet)
+		c, invalidator, err = poolDialer.dialDedicatedPool(ctx, dialPoolGroupVTOrc, tablet)
 		if err != nil {
 			return nil, err
 		}
@@ -630,6 +684,9 @@ func (client *Client) FullStatus(ctx context.Context, tablet *topodatapb.Tablet)
 
 	response, err := c.FullStatus(ctx, &tabletmanagerdatapb.FullStatusRequest{})
 	if err != nil {
+		if invalidator != nil {
+			invalidator()
+		}
 		return nil, err
 	}
 	return response.Status, nil
@@ -1101,9 +1158,10 @@ func (client *Client) Backup(ctx context.Context, tablet *topodatapb.Tablet, req
 // and dialing the other tablet every time is not practical.
 func (client *Client) CheckThrottler(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.CheckThrottlerRequest) (*tabletmanagerdatapb.CheckThrottlerResponse, error) {
 	var c tabletmanagerservicepb.TabletManagerClient
+	var invalidator invalidatorFunc
 	var err error
 	if poolDialer, ok := client.dialer.(poolDialer); ok {
-		c, err = poolDialer.dialPool(ctx, tablet)
+		c, invalidator, err = poolDialer.dialDedicatedPool(ctx, dialPoolGroupThrottler, tablet)
 		if err != nil {
 			return nil, err
 		}
@@ -1120,6 +1178,9 @@ func (client *Client) CheckThrottler(ctx context.Context, tablet *topodatapb.Tab
 
 	response, err := c.CheckThrottler(ctx, req)
 	if err != nil {
+		if invalidator != nil {
+			invalidator()
+		}
 		return nil, err
 	}
 	return response, nil
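The dedicated-pool contract introduced above: each (DialPoolGroup, tablet address) pair maps to exactly one long-lived connection, and the returned invalidator closes and evicts just that entry. Unlike rpcClientMap, whose connections live until Close(), a caller can recover from a stale connection (for example, a tablet restarted on different ports) without tearing down the whole client. A condensed caller-side sketch of the pattern used by FullStatus and CheckThrottler, with error handling trimmed:

    c, invalidator, err := poolDialer.dialDedicatedPool(ctx, dialPoolGroupVTOrc, tablet)
    if err != nil {
        return nil, err
    }
    resp, err := c.FullStatus(ctx, &tabletmanagerdatapb.FullStatusRequest{})
    if err != nil {
        // Any RPC error evicts the cached connection, so the next call
        // redials from scratch instead of reusing a possibly dead channel.
        invalidator()
        return nil, err
    }
    return resp.Status, nil

One consequence worth noting: the invalidator runs on any RPC error, including plain timeouts, so a healthy connection may occasionally be closed and redialed; that trade-off keeps the recovery logic simple.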
diff --git a/go/vt/vttablet/grpctmclient/client_test.go b/go/vt/vttablet/grpctmclient/client_test.go
new file mode 100644
index 00000000000..f842b216d8c
--- /dev/null
+++ b/go/vt/vttablet/grpctmclient/client_test.go
@@ -0,0 +1,177 @@
+/*
+Copyright 2024 The Vitess Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package grpctmclient
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc/connectivity"
+
+	"vitess.io/vitess/go/netutil"
+	tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
+	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
+)
+
+func TestDialDedicatedPool(t *testing.T) {
+	ctx := context.Background()
+	client := NewClient()
+	tablet := &topodatapb.Tablet{
+		Hostname: "localhost",
+		PortMap: map[string]int32{
+			"grpc": 15991,
+		},
+	}
+	addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
+	t.Run("dialPool", func(t *testing.T) {
+		poolDialer, ok := client.dialer.(poolDialer)
+		require.True(t, ok)
+
+		cli, invalidator, err := poolDialer.dialDedicatedPool(ctx, dialPoolGroupThrottler, tablet)
+		assert.NoError(t, err)
+		assert.NotNil(t, invalidator)
+		assert.NotNil(t, cli)
+	})
+
+	var cachedTmc *tmc
+	t.Run("maps", func(t *testing.T) {
+		rpcClient, ok := client.dialer.(*grpcClient)
+		require.True(t, ok)
+		assert.NotEmpty(t, rpcClient.rpcDialPoolMap)
+		assert.NotEmpty(t, rpcClient.rpcDialPoolMap[dialPoolGroupThrottler])
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupVTOrc])
+
+		c := rpcClient.rpcDialPoolMap[dialPoolGroupThrottler][addr]
+		assert.NotNil(t, c)
+		assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, c.cc.GetState())
+
+		cachedTmc = c
+	})
+
+	t.Run("CheckThrottler", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(ctx, time.Second)
+		defer cancel()
+
+		req := &tabletmanagerdatapb.CheckThrottlerRequest{}
+		_, err := client.CheckThrottler(ctx, tablet, req)
+		assert.Error(t, err)
+	})
+	t.Run("empty map", func(t *testing.T) {
+		rpcClient, ok := client.dialer.(*grpcClient)
+		require.True(t, ok)
+		assert.NotEmpty(t, rpcClient.rpcDialPoolMap)
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupThrottler])
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupVTOrc])
+
+		assert.Equal(t, connectivity.Shutdown, cachedTmc.cc.GetState())
+	})
+}
+
+func TestDialPool(t *testing.T) {
+	ctx := context.Background()
+	client := NewClient()
+	tablet := &topodatapb.Tablet{
+		Hostname: "localhost",
+		PortMap: map[string]int32{
+			"grpc": 15991,
+		},
+	}
+	addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
+	t.Run("dialPool", func(t *testing.T) {
+		poolDialer, ok := client.dialer.(poolDialer)
+		require.True(t, ok)
+
+		cli, err := poolDialer.dialPool(ctx, tablet)
+		assert.NoError(t, err)
+		assert.NotNil(t, cli)
+	})
+
+	var cachedTmc *tmc
+	t.Run("maps", func(t *testing.T) {
+		rpcClient, ok := client.dialer.(*grpcClient)
+		require.True(t, ok)
+		assert.Empty(t, rpcClient.rpcDialPoolMap)
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupThrottler])
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupVTOrc])
+
+		assert.NotEmpty(t, rpcClient.rpcClientMap)
+		assert.NotEmpty(t, rpcClient.rpcClientMap[addr])
+
+		ch := rpcClient.rpcClientMap[addr]
+		cachedTmc = <-ch
+		ch <- cachedTmc
+
+		assert.NotNil(t, cachedTmc)
+		assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState())
+	})
+
+	t.Run("CheckThrottler", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(ctx, time.Second)
+		defer cancel()
+
+		req := &tabletmanagerdatapb.CheckThrottlerRequest{}
+		_, err := client.CheckThrottler(ctx, tablet, req)
+		assert.Error(t, err)
+	})
+	t.Run("post throttler maps", func(t *testing.T) {
+		rpcClient, ok := client.dialer.(*grpcClient)
+		require.True(t, ok)
+
+		rpcClient.mu.Lock()
+		defer rpcClient.mu.Unlock()
+
+		assert.NotEmpty(t, rpcClient.rpcDialPoolMap)
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupThrottler])
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupVTOrc])
+
+		assert.NotEmpty(t, rpcClient.rpcClientMap)
+		assert.NotEmpty(t, rpcClient.rpcClientMap[addr])
+
+		assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState())
+	})
+	t.Run("ExecuteFetchAsDba", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(ctx, time.Second)
+		defer cancel()
+
+		req := &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{}
+		_, err := client.ExecuteFetchAsDba(ctx, tablet, true, req)
+		assert.Error(t, err)
+	})
+
+	t.Run("post ExecuteFetchAsDba maps", func(t *testing.T) {
+
+		rpcClient, ok := client.dialer.(*grpcClient)
+		require.True(t, ok)
+
+		rpcClient.mu.Lock()
+		defer rpcClient.mu.Unlock()
+
+		assert.NotEmpty(t, rpcClient.rpcDialPoolMap)
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupThrottler])
+		assert.Empty(t, rpcClient.rpcDialPoolMap[dialPoolGroupVTOrc])
+
+		// The default pools are unaffected. Invalidator does not run, connections are not closed.
+		assert.NotEmpty(t, rpcClient.rpcClientMap)
+		assert.NotEmpty(t, rpcClient.rpcClientMap[addr])
+
+		assert.NotNil(t, cachedTmc)
+		assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState())
+	})
+}
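A closing note on the connectivity assertions in these unit tests: nothing is expected to be listening on localhost:15991, so a freshly dialed connection legitimately hovers between connectivity.Connecting and connectivity.TransientFailure, and the tests accept either state. Only an invalidated entry must report connectivity.Shutdown, because the invalidator calls cc.Close() before evicting it. A minimal sketch of that final check, assuming cc is the *grpc.ClientConn cached in the evicted tmc:

    invalidator() // closes cc and deletes the pool entry
    if state := cc.GetState(); state != connectivity.Shutdown {
        t.Fatalf("expected invalidated connection to be shut down, got %v", state)
    }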