diff --git a/examples/local/vstream_client.go b/examples/local/vstream_client.go index 48d23247086..98d2129f898 100644 --- a/examples/local/vstream_client.go +++ b/examples/local/vstream_client.go @@ -67,7 +67,7 @@ func main() { Filter: "select * from customer", }}, } - conn, err := vtgateconn.Dial("localhost:15991") + conn, err := vtgateconn.Dial(ctx, "localhost:15991") if err != nil { log.Fatal(err) } diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index d374a245560..3ef4e8a1b3b 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -920,7 +920,7 @@ func (cluster *LocalProcessCluster) ExecOnVTGate(ctx context.Context, addr strin return nil, err } - conn, err := vtgateconn.Dial(addr) + conn, err := vtgateconn.Dial(ctx, addr) if err != nil { return nil, err } diff --git a/go/test/endtoend/cluster/cluster_util.go b/go/test/endtoend/cluster/cluster_util.go index 057ff5e9c9c..061e632dde7 100644 --- a/go/test/endtoend/cluster/cluster_util.go +++ b/go/test/endtoend/cluster/cluster_util.go @@ -482,13 +482,13 @@ func WaitForHealthyShard(vtctldclient *VtctldClientProcess, keyspace, shard stri } // DialVTGate returns a VTGate grpc connection. -func DialVTGate(name, addr, username, password string) (*vtgateconn.VTGateConn, error) { +func DialVTGate(ctx context.Context, name, addr, username, password string) (*vtgateconn.VTGateConn, error) { clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password} creds := grpc.WithPerRPCCredentials(clientCreds) dialerFunc := grpcvtgateconn.Dial(creds) dialerName := name vtgateconn.RegisterDialer(dialerName, dialerFunc) - return vtgateconn.DialProtocol(dialerName, addr) + return vtgateconn.DialProtocol(ctx, dialerName, addr) } // PrintFiles prints the files that are asked for. If no file is specified, all the files are printed. diff --git a/go/test/endtoend/messaging/message_test.go b/go/test/endtoend/messaging/message_test.go index 32c4401a1da..7e1190c16bb 100644 --- a/go/test/endtoend/messaging/message_test.go +++ b/go/test/endtoend/messaging/message_test.go @@ -573,7 +573,7 @@ func VtgateGrpcConn(ctx context.Context, cluster *cluster.LocalProcessCluster) ( stream := new(VTGateStream) stream.ctx = ctx stream.host = fmt.Sprintf("%s:%d", cluster.Hostname, cluster.VtgateProcess.GrpcPort) - conn, err := vtgateconn.Dial(stream.host) + conn, err := vtgateconn.Dial(ctx, stream.host) // init components stream.respChan = make(chan *sqltypes.Result) stream.VTGateConn = conn diff --git a/go/test/endtoend/recovery/unshardedrecovery/recovery.go b/go/test/endtoend/recovery/unshardedrecovery/recovery.go index cd3d5bd9f04..1ebb7c2647f 100644 --- a/go/test/endtoend/recovery/unshardedrecovery/recovery.go +++ b/go/test/endtoend/recovery/unshardedrecovery/recovery.go @@ -308,7 +308,7 @@ func TestRecoveryImpl(t *testing.T) { // Build vtgate grpc connection grpcAddress := fmt.Sprintf("%s:%d", localCluster.Hostname, localCluster.VtgateGrpcPort) - vtgateConn, err := vtgateconn.Dial(grpcAddress) + vtgateConn, err := vtgateconn.Dial(context.Background(), grpcAddress) assert.NoError(t, err) defer vtgateConn.Close() session := vtgateConn.Session("@replica", nil) diff --git a/go/test/endtoend/tabletgateway/vtgate_test.go b/go/test/endtoend/tabletgateway/vtgate_test.go index 0c7b68b67e2..a48a22f2cb0 100644 --- a/go/test/endtoend/tabletgateway/vtgate_test.go +++ b/go/test/endtoend/tabletgateway/vtgate_test.go @@ -302,7 +302,7 @@ func TestStreamingRPCStuck(t *testing.T) { } // Connect to vtgate and run a streaming query. - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "test_user", "") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "test_user", "") require.NoError(t, err) stream, err := vtgateConn.Session("", &querypb.ExecuteOptions{}).StreamExecute(ctx, "select * from customer", map[string]*querypb.BindVariable{}) require.NoError(t, err) diff --git a/go/test/endtoend/vreplication/vreplication_test.go b/go/test/endtoend/vreplication/vreplication_test.go index 5005b2b0acc..4c72781df29 100644 --- a/go/test/endtoend/vreplication/vreplication_test.go +++ b/go/test/endtoend/vreplication/vreplication_test.go @@ -513,7 +513,7 @@ func testVStreamCellFlag(t *testing.T) { for _, tc := range vstreamTestCases { t.Run("VStreamCellsFlag/"+tc.cells, func(t *testing.T) { - conn, err := vtgateconn.Dial(fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort)) + conn, err := vtgateconn.Dial(ctx, fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort)) require.NoError(t, err) defer conn.Close() diff --git a/go/test/endtoend/vreplication/vschema_load_test.go b/go/test/endtoend/vreplication/vschema_load_test.go index a5c414ad3a0..6ca8dcfe472 100644 --- a/go/test/endtoend/vreplication/vschema_load_test.go +++ b/go/test/endtoend/vreplication/vschema_load_test.go @@ -94,7 +94,7 @@ func TestVSchemaChangesUnderLoad(t *testing.T) { Filter: "select * from customer", }}, } - conn, err := vtgateconn.Dial(net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort))) + conn, err := vtgateconn.Dial(ctx, net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort))) require.NoError(t, err) defer conn.Close() diff --git a/go/test/endtoend/vreplication/vstream_test.go b/go/test/endtoend/vreplication/vstream_test.go index 48fac9b0e00..e13c3e24e80 100644 --- a/go/test/endtoend/vreplication/vstream_test.go +++ b/go/test/endtoend/vreplication/vstream_test.go @@ -58,7 +58,7 @@ func testVStreamWithFailover(t *testing.T, failover bool) { testVStreamFrom(t, vtgate, "product", 2) }) ctx := context.Background() - vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) + vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) if err != nil { log.Fatal(err) } @@ -259,7 +259,7 @@ func testVStreamStopOnReshardFlag(t *testing.T, stopOnReshard bool, baseTabletID vc.AddKeyspace(t, []*Cell{defaultCell}, "sharded", "-80,80-", vschemaSharded, schemaSharded, defaultReplicas, defaultRdonly, baseTabletID+200, nil) ctx := context.Background() - vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) + vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) if err != nil { log.Fatal(err) } @@ -398,7 +398,7 @@ func testVStreamCopyMultiKeyspaceReshard(t *testing.T, baseTabletID int) numEven require.NoError(t, err) ctx := context.Background() - vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) + vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) if err != nil { log.Fatal(err) } @@ -550,7 +550,7 @@ func TestMultiVStreamsKeyspaceReshard(t *testing.T) { defer vtgateConn.Close() verifyClusterHealth(t, vc) - vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) + vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) require.NoError(t, err) defer vstreamConn.Close() diff --git a/go/test/endtoend/vtcombo/recreate/recreate_test.go b/go/test/endtoend/vtcombo/recreate/recreate_test.go index 1d21a32fbf9..e66edb7688a 100644 --- a/go/test/endtoend/vtcombo/recreate/recreate_test.go +++ b/go/test/endtoend/vtcombo/recreate/recreate_test.go @@ -95,7 +95,7 @@ func TestMain(m *testing.M) { func TestDropAndRecreateWithSameShards(t *testing.T) { ctx := context.Background() - conn, err := vtgateconn.Dial(grpcAddress) + conn, err := vtgateconn.Dial(ctx, grpcAddress) require.Nil(t, err) defer conn.Close() diff --git a/go/test/endtoend/vtcombo/vttest_sample_test.go b/go/test/endtoend/vtcombo/vttest_sample_test.go index bb09a4ad336..daeb5e8deb9 100644 --- a/go/test/endtoend/vtcombo/vttest_sample_test.go +++ b/go/test/endtoend/vtcombo/vttest_sample_test.go @@ -131,7 +131,7 @@ func TestStandalone(t *testing.T) { require.Contains(t, tmp[0], "vtcombo") ctx := context.Background() - conn, err := vtgateconn.Dial(grpcAddress) + conn, err := vtgateconn.Dial(ctx, grpcAddress) require.NoError(t, err) defer conn.Close() diff --git a/go/test/endtoend/vtgate/foreignkey/fk_test.go b/go/test/endtoend/vtgate/foreignkey/fk_test.go index 509dc3c88b9..5a34a2b49c0 100644 --- a/go/test/endtoend/vtgate/foreignkey/fk_test.go +++ b/go/test/endtoend/vtgate/foreignkey/fk_test.go @@ -182,7 +182,7 @@ func TestUpdateWithFK(t *testing.T) { // TestVstreamForFKBinLog tests that dml queries with fks are written with child row first approach in the binary logs. func TestVstreamForFKBinLog(t *testing.T) { - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") require.NoError(t, err) defer vtgateConn.Close() diff --git a/go/test/endtoend/vtgate/grpc_api/acl_test.go b/go/test/endtoend/vtgate/grpc_api/acl_test.go index a5957523924..2819a3e41d1 100644 --- a/go/test/endtoend/vtgate/grpc_api/acl_test.go +++ b/go/test/endtoend/vtgate/grpc_api/acl_test.go @@ -32,7 +32,7 @@ func TestEffectiveCallerIDWithAccess(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "some_other_user", "test_password") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "some_other_user", "test_password") require.NoError(t, err) defer vtgateConn.Close() @@ -48,7 +48,7 @@ func TestEffectiveCallerIDWithNoAccess(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password") require.NoError(t, err) defer vtgateConn.Close() @@ -66,7 +66,7 @@ func TestAuthenticatedUserWithAccess(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password") require.NoError(t, err) defer vtgateConn.Close() @@ -81,7 +81,7 @@ func TestAuthenticatedUserNoAccess(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_no_access", "test_password") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_no_access", "test_password") require.NoError(t, err) defer vtgateConn.Close() @@ -98,7 +98,7 @@ func TestUnauthenticatedUser(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "", "") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "", "") require.NoError(t, err) defer vtgateConn.Close() diff --git a/go/test/endtoend/vtgate/grpc_api/execute_test.go b/go/test/endtoend/vtgate/grpc_api/execute_test.go index 2d57e065cdd..b1a5f3b8d80 100644 --- a/go/test/endtoend/vtgate/grpc_api/execute_test.go +++ b/go/test/endtoend/vtgate/grpc_api/execute_test.go @@ -38,7 +38,7 @@ func TestTransactionsWithGRPCAPI(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password") + vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password") require.NoError(t, err) defer vtgateConn.Close() diff --git a/go/test/endtoend/vtgate/queries/reference/main_test.go b/go/test/endtoend/vtgate/queries/reference/main_test.go index 8953573690a..4c9440ca4ff 100644 --- a/go/test/endtoend/vtgate/queries/reference/main_test.go +++ b/go/test/endtoend/vtgate/queries/reference/main_test.go @@ -156,7 +156,7 @@ func TestMain(m *testing.M) { go func() { ctx := context.Background() vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort) - vtgateConn, err := vtgateconn.Dial(vtgateAddr) + vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr) if err != nil { done <- false return @@ -234,7 +234,7 @@ func TestMain(m *testing.M) { ctx := context.Background() vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort) - vtgateConn, err := vtgateconn.Dial(vtgateAddr) + vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr) if err != nil { return 1 } diff --git a/go/vt/grpcclient/client.go b/go/vt/grpcclient/client.go index 457ac34de76..7524298514e 100644 --- a/go/vt/grpcclient/client.go +++ b/go/vt/grpcclient/client.go @@ -19,6 +19,7 @@ limitations under the License. package grpcclient import ( + "context" "crypto/tls" "sync" "time" @@ -96,6 +97,16 @@ func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([ // failFast is a non-optional parameter because callers are required to specify // what that should be. func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return DialContext(context.Background(), target, failFast, opts...) +} + +// DialContext creates a grpc connection to the given target. Setup steps are +// covered by the context deadline, and, if WithBlock is specified in the dial +// options, connection establishment steps are covered by the context as well. +// +// failFast is a non-optional parameter because callers are required to specify +// what that should be. +func DialContext(ctx context.Context, target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) { msgSize := grpccommon.MaxMessageSize() newopts := []grpc.DialOption{ grpc.WithDefaultCallOptions( @@ -138,7 +149,7 @@ func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.Clie newopts = append(newopts, interceptors()...) - return grpc.Dial(target, newopts...) + return grpc.DialContext(ctx, target, newopts...) } func interceptors() []grpc.DialOption { diff --git a/go/vt/grpcoptionaltls/server_test.go b/go/vt/grpcoptionaltls/server_test.go index 32fdfdc154d..e419294b172 100755 --- a/go/vt/grpcoptionaltls/server_test.go +++ b/go/vt/grpcoptionaltls/server_test.go @@ -97,7 +97,7 @@ func TestOptionalTLS(t *testing.T) { testFunc := func(t *testing.T, dialOpt grpc.DialOption) { ctx, cancel := context.WithTimeout(testCtx, 5*time.Second) defer cancel() - conn, err := grpc.NewClient(addr, dialOpt) + conn, err := grpc.DialContext(ctx, addr, dialOpt) if err != nil { t.Fatalf("failed to connect to the server %v", err) } diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go index 554c5efc9cb..4a965399e9c 100644 --- a/go/vt/vitessdriver/driver.go +++ b/go/vt/vitessdriver/driver.go @@ -174,13 +174,13 @@ func (d drv) newConnector(cfg Configuration) (driver.Connector, error) { } // Connect implements the database/sql/driver.Connector interface. -func (c *connector) Connect(_ context.Context) (driver.Conn, error) { +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { conn := &conn{ cfg: c.cfg, convert: c.convert, } - if err := conn.dial(); err != nil { + if err := conn.dial(ctx); err != nil { return nil, err } @@ -267,9 +267,9 @@ type conn struct { session *vtgateconn.VTGateSession } -func (c *conn) dial() error { +func (c *conn) dial(ctx context.Context) error { var err error - c.conn, err = vtgateconn.DialProtocol(c.cfg.Protocol, c.cfg.Address) + c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address) if err != nil { return err } diff --git a/go/vt/vtadmin/grpcserver/server_test.go b/go/vt/vtadmin/grpcserver/server_test.go index 53f0cbcd48c..4f43c4413ce 100644 --- a/go/vt/vtadmin/grpcserver/server_test.go +++ b/go/vt/vtadmin/grpcserver/server_test.go @@ -64,7 +64,7 @@ func TestServer(t *testing.T) { } close(readyCh) - conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) assert.NoError(t, err) defer conn.Close() diff --git a/go/vt/vtgate/endtoend/vstream_test.go b/go/vt/vtgate/endtoend/vstream_test.go index 0c3c6a6f2cd..871e6cf98c3 100644 --- a/go/vt/vtgate/endtoend/vstream_test.go +++ b/go/vt/vtgate/endtoend/vstream_test.go @@ -41,7 +41,7 @@ import ( ) func initialize(ctx context.Context, t *testing.T) (*vtgateconn.VTGateConn, *mysql.Conn, *mysql.Conn, func()) { - gconn, err := vtgateconn.Dial(grpcAddress) + gconn, err := vtgateconn.Dial(ctx, grpcAddress) if err != nil { t.Fatal(err) } diff --git a/go/vt/vtgate/fakerpcvtgateconn/conn.go b/go/vt/vtgate/fakerpcvtgateconn/conn.go index 3c2ed352f14..3f6236ea9ec 100644 --- a/go/vt/vtgate/fakerpcvtgateconn/conn.go +++ b/go/vt/vtgate/fakerpcvtgateconn/conn.go @@ -61,7 +61,7 @@ func RegisterFakeVTGateConnDialer() (*FakeVTGateConn, string) { impl := &FakeVTGateConn{ execMap: make(map[string]*queryResponse), } - vtgateconn.RegisterDialer(protocol, func(address string) (vtgateconn.Impl, error) { + vtgateconn.RegisterDialer(protocol, func(ctx context.Context, address string) (vtgateconn.Impl, error) { return impl, nil }) return impl, protocol diff --git a/go/vt/vtgate/grpcvtgateconn/conn.go b/go/vt/vtgate/grpcvtgateconn/conn.go index d2b40aef6b7..a681e3661cd 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn.go +++ b/go/vt/vtgate/grpcvtgateconn/conn.go @@ -72,13 +72,13 @@ type vtgateConn struct { c vtgateservicepb.VitessClient } -func dial(addr string) (vtgateconn.Impl, error) { - return Dial()(addr) +func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) { + return Dial()(ctx, addr) } // Dial produces a vtgateconn.DialerFunc with custom options. func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc { - return func(address string) (vtgateconn.Impl, error) { + return func(ctx context.Context, address string) (vtgateconn.Impl, error) { opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name) if err != nil { return nil, err @@ -86,7 +86,7 @@ func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc { opts = append(opts, opt) - cc, err := grpcclient.Dial(address, grpcclient.FailFast(false), opts...) + cc, err := grpcclient.DialContext(ctx, address, grpcclient.FailFast(false), opts...) if err != nil { return nil, err } @@ -99,6 +99,14 @@ func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc { } } +// DialWithOpts allows for custom dial options to be set on a vtgateConn. +// +// Deprecated: the context parameter cannot be used by the returned +// vtgateconn.DialerFunc and thus has no effect. Use Dial instead. +func DialWithOpts(_ context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc { + return Dial(opts...) +} + func (conn *vtgateConn) Execute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) { request := &vtgatepb.ExecuteRequest{ CallerId: callerid.EffectiveCallerIDFromContext(ctx), diff --git a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go index 6053c1e1536..55a067807bd 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go +++ b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go @@ -50,7 +50,8 @@ func TestGRPCVTGateConn(t *testing.T) { go server.Serve(listener) // Create a Go RPC client connecting to the server - client, err := dial(listener.Addr().String()) + ctx := context.Background() + client, err := dial(ctx, listener.Addr().String()) if err != nil { t.Fatalf("dial failed: %v", err) } @@ -103,6 +104,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { } // Create a Go RPC client connecting to the server + ctx := context.Background() fs := pflag.NewFlagSet("", pflag.ContinueOnError) grpcclient.RegisterFlags(fs) @@ -112,7 +114,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { f.Name(), }) require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name()) - client, err := dial(listener.Addr().String()) + client, err := dial(ctx, listener.Addr().String()) if err != nil { t.Fatalf("dial failed: %v", err) } @@ -143,6 +145,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { } // Create a Go RPC client connecting to the server + ctx = context.Background() fs = pflag.NewFlagSet("", pflag.ContinueOnError) grpcclient.RegisterFlags(fs) @@ -152,12 +155,12 @@ func TestGRPCVTGateConnAuth(t *testing.T) { f.Name(), }) require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name()) - client, err = dial(listener.Addr().String()) + client, err = dial(ctx, listener.Addr().String()) if err != nil { t.Fatalf("dial failed: %v", err) } RegisterTestDialProtocol(client) - conn, _ := vtgateconn.DialProtocol("test", "") + conn, _ := vtgateconn.DialProtocol(context.Background(), "test", "") // run the test suite _, err = conn.Session("", nil).Execute(context.Background(), "select * from t", nil) want := "rpc error: code = Unauthenticated desc = username and password must be provided" diff --git a/go/vt/vtgate/grpcvtgateconn/suite_test.go b/go/vt/vtgate/grpcvtgateconn/suite_test.go index 25d81802bf4..e5cd5c3ac81 100644 --- a/go/vt/vtgate/grpcvtgateconn/suite_test.go +++ b/go/vt/vtgate/grpcvtgateconn/suite_test.go @@ -261,7 +261,7 @@ func CreateFakeServer(t *testing.T) vtgateservice.VTGateService { // RegisterTestDialProtocol registers a vtgateconn implementation under the "test" protocol func RegisterTestDialProtocol(impl vtgateconn.Impl) { - vtgateconn.RegisterDialer("test", func(address string) (vtgateconn.Impl, error) { + vtgateconn.RegisterDialer("test", func(ctx context.Context, address string) (vtgateconn.Impl, error) { return impl, nil }) } @@ -277,10 +277,10 @@ func (f *fakeVTGateService) HandlePanic(err *error) { // RunTests runs all the tests func RunTests(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGateService) { - vtgateconn.RegisterDialer("test", func(address string) (vtgateconn.Impl, error) { + vtgateconn.RegisterDialer("test", func(ctx context.Context, address string) (vtgateconn.Impl, error) { return impl, nil }) - conn, err := vtgateconn.DialProtocol("test", "") + conn, err := vtgateconn.DialProtocol(context.Background(), "test", "") if err != nil { t.Fatalf("Got err: %v from vtgateconn.DialProtocol", err) } @@ -304,7 +304,7 @@ func RunTests(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGat // RunErrorTests runs all the tests that expect errors func RunErrorTests(t *testing.T, fakeServer vtgateservice.VTGateService) { - conn, err := vtgateconn.DialProtocol("test", "") + conn, err := vtgateconn.DialProtocol(context.Background(), "test", "") if err != nil { t.Fatalf("Got err: %v from vtgateconn.DialProtocol", err) } diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index f1f0fd77f39..ae0da3fdf43 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -190,7 +190,7 @@ type Impl interface { // DialerFunc represents a function that will return an Impl // object that can communicate with a VTGate. -type DialerFunc func(address string) (Impl, error) +type DialerFunc func(ctx context.Context, address string) (Impl, error) var ( dialers = make(map[string]DialerFunc) @@ -221,7 +221,7 @@ func DeregisterDialer(name string) { } // DialProtocol dials a specific protocol, and returns the *VTGateConn -func DialProtocol(protocol string, address string) (*VTGateConn, error) { +func DialProtocol(ctx context.Context, protocol string, address string) (*VTGateConn, error) { dialersM.Lock() dialer, ok := dialers[protocol] dialersM.Unlock() @@ -229,7 +229,7 @@ func DialProtocol(protocol string, address string) (*VTGateConn, error) { if !ok { return nil, fmt.Errorf("no dialer registered for VTGate protocol %s", protocol) } - impl, err := dialer(address) + impl, err := dialer(ctx, address) if err != nil { return nil, err } @@ -240,6 +240,6 @@ func DialProtocol(protocol string, address string) (*VTGateConn, error) { // Dial dials using the command-line specified protocol, and returns // the *VTGateConn. -func Dial(address string) (*VTGateConn, error) { - return DialProtocol(vtgateProtocol, address) +func Dial(ctx context.Context, address string) (*VTGateConn, error) { + return DialProtocol(ctx, vtgateProtocol, address) } diff --git a/go/vt/vtgate/vtgateconn/vtgateconn_test.go b/go/vt/vtgate/vtgateconn/vtgateconn_test.go index 04c696c5a8a..523492328e9 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn_test.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn_test.go @@ -17,11 +17,12 @@ limitations under the License. package vtgateconn import ( + "context" "testing" ) func TestRegisterDialer(t *testing.T) { - dialerFunc := func(string) (Impl, error) { + dialerFunc := func(context.Context, string) (Impl, error) { return nil, nil } RegisterDialer("test1", dialerFunc) @@ -30,14 +31,14 @@ func TestRegisterDialer(t *testing.T) { func TestGetDialerWithProtocol(t *testing.T) { protocol := "test2" - _, err := DialProtocol(protocol, "") + _, err := DialProtocol(context.Background(), protocol, "") if err == nil || err.Error() != "no dialer registered for VTGate protocol "+protocol { t.Fatalf("protocol: %s is not registered, should return error: %v", protocol, err) } - RegisterDialer(protocol, func(string) (Impl, error) { + RegisterDialer(protocol, func(context.Context, string) (Impl, error) { return nil, nil }) - c, err := DialProtocol(protocol, "") + c, err := DialProtocol(context.Background(), protocol, "") if err != nil || c == nil { t.Fatalf("dialerFunc has been registered, should not get nil: %v %v", err, c) } @@ -46,13 +47,13 @@ func TestGetDialerWithProtocol(t *testing.T) { func TestDeregisterDialer(t *testing.T) { const protocol = "test3" - RegisterDialer(protocol, func(string) (Impl, error) { + RegisterDialer(protocol, func(context.Context, string) (Impl, error) { return nil, nil }) DeregisterDialer(protocol) - _, err := DialProtocol(protocol, "") + _, err := DialProtocol(context.Background(), protocol, "") if err == nil || err.Error() != "no dialer registered for VTGate protocol "+protocol { t.Fatalf("protocol: %s is not registered, should return error: %v", protocol, err) } diff --git a/go/vt/vttablet/endtoend/framework/server.go b/go/vt/vttablet/endtoend/framework/server.go index 94224f029fc..95c8114fd9f 100644 --- a/go/vt/vttablet/endtoend/framework/server.go +++ b/go/vt/vttablet/endtoend/framework/server.go @@ -64,7 +64,7 @@ func StartCustomServer(ctx context.Context, connParams, connAppDebugParams mysql // Setup a fake vtgate server. protocol := "resolveTest" vtgateconn.SetVTGateProtocol(protocol) - vtgateconn.RegisterDialer(protocol, func(string) (vtgateconn.Impl, error) { + vtgateconn.RegisterDialer(protocol, func(context.Context, string) (vtgateconn.Impl, error) { return &txResolver{ FakeVTGateConn: fakerpcvtgateconn.FakeVTGateConn{}, }, nil diff --git a/go/vt/vttablet/grpctmclient/cached_client.go b/go/vt/vttablet/grpctmclient/cached_client.go index 30577684048..c0dd751ec30 100644 --- a/go/vt/vttablet/grpctmclient/cached_client.go +++ b/go/vt/vttablet/grpctmclient/cached_client.go @@ -143,7 +143,7 @@ func (dialer *cachedConnDialer) dial(ctx context.Context, tablet *topodatapb.Tab dialer.connWaitSema.Release(1) return client, closer, err } - return dialer.newdial(addr) + return dialer.newdial(ctx, addr) } defer func() { @@ -156,7 +156,7 @@ func (dialer *cachedConnDialer) dial(ctx context.Context, tablet *topodatapb.Tab dialerStats.DialTimeouts.Add(1) return nil, nil, ctx.Err() default: - if client, closer, found, err := dialer.pollOnce(addr); found { + if client, closer, found, err := dialer.pollOnce(ctx, addr); found { return client, closer, err } } @@ -204,7 +204,7 @@ func (dialer *cachedConnDialer) tryFromCache(addr string, locker sync.Locker) (c // // It returns a TabletManagerClient impl, an io.Closer, a flag to indicate // whether the dial() poll loop should exit, and an error. -func (dialer *cachedConnDialer) pollOnce(addr string) (client tabletmanagerservicepb.TabletManagerClient, closer io.Closer, found bool, err error) { +func (dialer *cachedConnDialer) pollOnce(ctx context.Context, addr string) (client tabletmanagerservicepb.TabletManagerClient, closer io.Closer, found bool, err error) { dialer.m.Lock() if client, closer, found, err := dialer.tryFromCache(addr, nil); found { @@ -225,7 +225,7 @@ func (dialer *cachedConnDialer) pollOnce(addr string) (client tabletmanagerservi conn.cc.Close() dialer.m.Unlock() - client, closer, err = dialer.newdial(addr) + client, closer, err = dialer.newdial(ctx, addr) return client, closer, true, err } @@ -236,14 +236,14 @@ func (dialer *cachedConnDialer) pollOnce(addr string) (client tabletmanagerservi // // It returns the three-tuple of client-interface, closer, and error that the // main dial func returns. -func (dialer *cachedConnDialer) newdial(addr string) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { +func (dialer *cachedConnDialer) newdial(ctx context.Context, addr string) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name) if err != nil { dialer.connWaitSema.Release(1) return nil, nil, err } - cc, err := grpcclient.Dial(addr, grpcclient.FailFast(false), opt) + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) if err != nil { dialer.connWaitSema.Release(1) return nil, nil, err diff --git a/go/vt/vttablet/grpctmclient/client_test.go b/go/vt/vttablet/grpctmclient/client_test.go index 9c211cdc846..1487303163d 100644 --- a/go/vt/vttablet/grpctmclient/client_test.go +++ b/go/vt/vttablet/grpctmclient/client_test.go @@ -67,7 +67,7 @@ func TestDialDedicatedPool(t *testing.T) { c := rpcClient.rpcDialPoolMap[dialPoolGroupThrottler][addr] assert.NotNil(t, c) - assert.Contains(t, []connectivity.State{connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure}, c.cc.GetState()) + assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, c.cc.GetState()) cachedTmc = c }) @@ -126,7 +126,7 @@ func TestDialPool(t *testing.T) { ch <- cachedTmc assert.NotNil(t, cachedTmc) - assert.Contains(t, []connectivity.State{connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) + assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) }) t.Run("CheckThrottler", func(t *testing.T) { @@ -151,7 +151,7 @@ func TestDialPool(t *testing.T) { assert.NotEmpty(t, rpcClient.rpcClientMap) assert.NotEmpty(t, rpcClient.rpcClientMap[addr]) - assert.Contains(t, []connectivity.State{connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) + assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) }) t.Run("ExecuteFetchAsDba", func(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) @@ -179,6 +179,6 @@ func TestDialPool(t *testing.T) { assert.NotEmpty(t, rpcClient.rpcClientMap[addr]) assert.NotNil(t, cachedTmc) - assert.Contains(t, []connectivity.State{connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) + assert.Contains(t, []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, cachedTmc.cc.GetState()) }) } diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go index cb8c3e6e51c..7e8ecc06a75 100644 --- a/go/vt/vttablet/tabletserver/tx_engine.go +++ b/go/vt/vttablet/tabletserver/tx_engine.go @@ -480,7 +480,7 @@ func (te *TxEngine) startWatchdog() { return } - coordConn, err := vtgateconn.Dial(te.coordinatorAddress) + coordConn, err := vtgateconn.Dial(ctx, te.coordinatorAddress) if err != nil { te.env.Stats().InternalErrors.Add("WatchdogFail", 1) log.Errorf("Error connecting to coordinator '%v': %v", te.coordinatorAddress, err) diff --git a/go/vt/vttablet/tabletserver/tx_executor_test.go b/go/vt/vttablet/tabletserver/tx_executor_test.go index fe34171fea5..2651eb2a6cc 100644 --- a/go/vt/vttablet/tabletserver/tx_executor_test.go +++ b/go/vt/vttablet/tabletserver/tx_executor_test.go @@ -454,7 +454,7 @@ func TestExecutorResolveTransaction(t *testing.T) { defer func() { vtgateconn.SetVTGateProtocol(oldValue) }() - vtgateconn.RegisterDialer(protocol, func(string) (vtgateconn.Impl, error) { + vtgateconn.RegisterDialer(protocol, func(context.Context, string) (vtgateconn.Impl, error) { return &FakeVTGateConn{ FakeVTGateConn: fakerpcvtgateconn.FakeVTGateConn{}, }, nil diff --git a/go/vtbench/client.go b/go/vtbench/client.go index 1dbb1c6a016..1a6751a62db 100644 --- a/go/vtbench/client.go +++ b/go/vtbench/client.go @@ -93,7 +93,7 @@ func (c *grpcVtgateConn) connect(ctx context.Context, cp ConnParams) error { conn, ok := vtgateConns[address] if !ok { var err error - conn, err = vtgateconn.DialProtocol("grpc", address) + conn, err = vtgateconn.DialProtocol(ctx, "grpc", address) if err != nil { return err } diff --git a/tools/rowlog/rowlog.go b/tools/rowlog/rowlog.go index 20c9ac7902e..8092159c6b6 100644 --- a/tools/rowlog/rowlog.go +++ b/tools/rowlog/rowlog.go @@ -154,7 +154,7 @@ func startStreaming(ctx context.Context, vtgate, vtctld, keyspace, tablet, table }}, FieldEventMode: 1, } - conn, err := vtgateconn.Dial(vtgate) + conn, err := vtgateconn.Dial(ctx, vtgate) if err != nil { log.Fatal(err) }