From 26c8de745aed4df06f8df080b9e6147161c624de Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Mon, 4 Nov 2024 10:43:08 -0500 Subject: [PATCH] [v16] add a hidden `--tunnel` flag to `tsh db connect` to force local proxy tunnel (#48318) * add a hidden `--tunnel` flag to `tsh db connect` to force local proxy with tunnel * move msg to const --- tool/tsh/common/db.go | 15 ++++++++++++-- tool/tsh/common/db_test.go | 40 +++++++++++++++++++++++++++++--------- tool/tsh/common/tsh.go | 1 + 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 738914fa86e4e..ad68f80cb5fcb 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -759,7 +759,7 @@ func onDatabaseConnect(cf *CLIConf) error { return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, dbInfo.RouteToDatabase)) } - requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase) + requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase, cf.LocalProxyTunnel) if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil { return trace.Wrap(err) } @@ -1674,8 +1674,13 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route tlsca.RouteToDa return &out } -func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement { +func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase, tunnelFlag bool) *dbLocalProxyRequirement { r := getDBLocalProxyRequirement(tc, route) + // Forces local proxy tunnel when --tunnel is on. + if !r.tunnel && tunnelFlag { + r.addLocalProxyWithTunnel(dbConnectRequireReasonTunnelFlag) + } + // Forces local proxy when cluster has TLS routing enabled. if !r.localProxy && tc.TLSRoutingEnabled { r.addLocalProxy(formatTLSRoutingReason(tc.SiteName)) } @@ -1851,6 +1856,12 @@ const ( dbFormatYAML = "yaml" ) +const ( + // dbConnectRequireReasonTunnelFlag is the reason used in local proxy + // requirement calculation when --tunnel flag is specified. + dbConnectRequireReasonTunnelFlag = "--tunnel flag is specified" +) + var ( // dbCmdUnsupportedTemplate is the error message printed when some // database subcommands are not supported. diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index ccb239057cf3d..b5ca387e711e5 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -578,11 +578,13 @@ func TestLocalProxyRequirement(t *testing.T) { defaultAuthPref, err := authServer.GetAuthPreference(ctx) require.NoError(t, err) tests := map[string]struct { - clusterAuthPref types.AuthPreference - route *tlsca.RouteToDatabase - setupTC func(*client.TeleportClient) - wantLocalProxy bool - wantTunnel bool + clusterAuthPref types.AuthPreference + route *tlsca.RouteToDatabase + setupTC func(*client.TeleportClient) + tunnelFlag bool + wantLocalProxy bool + wantTunnel bool + wantTunnelReason string }{ "tunnel not required": { clusterAuthPref: defaultAuthPref, @@ -600,8 +602,9 @@ func TestLocalProxyRequirement(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - wantLocalProxy: true, - wantTunnel: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: "MFA is required", }, "local proxy not required for separate port": { clusterAuthPref: defaultAuthPref, @@ -622,6 +625,25 @@ func TestLocalProxyRequirement(t *testing.T) { wantLocalProxy: true, wantTunnel: false, }, + "tunnel required by tunnel flag": { + clusterAuthPref: defaultAuthPref, + tunnelFlag: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: dbConnectRequireReasonTunnelFlag, + }, + "tunnel required for separate port by tunnel flag": { + clusterAuthPref: defaultAuthPref, + setupTC: func(tc *client.TeleportClient) { + tc.TLSRoutingEnabled = false + tc.TLSRoutingConnUpgradeRequired = false + tc.PostgresProxyAddr = "separate.postgres.hostport:8888" + }, + tunnelFlag: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: dbConnectRequireReasonTunnelFlag, + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { @@ -648,12 +670,12 @@ func TestLocalProxyRequirement(t *testing.T) { Username: "alice", Database: "postgres", } - requires := getDBConnectLocalProxyRequirement(ctx, tc, route) + requires := getDBConnectLocalProxyRequirement(ctx, tc, route, tt.tunnelFlag) require.Equal(t, tt.wantLocalProxy, requires.localProxy) require.Equal(t, tt.wantTunnel, requires.tunnel) if requires.tunnel { require.Len(t, requires.tunnelReasons, 1) - require.Contains(t, requires.tunnelReasons[0], "MFA is required") + require.Contains(t, requires.tunnelReasons[0], tt.wantTunnelReason) } }) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index fd3b4a392f053..534c5f37fd1cb 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -994,6 +994,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) + dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel) // join join := app.Command("join", "Join the active SSH or Kubernetes session.")