diff --git a/lib/tbot/service_application_tunnel.go b/lib/tbot/service_application_tunnel.go index 06dcde38faeb8..427a4558d4e5c 100644 --- a/lib/tbot/service_application_tunnel.go +++ b/lib/tbot/service_application_tunnel.go @@ -133,9 +133,9 @@ func (s *ApplicationTunnelService) buildLocalProxyConfig(ctx context.Context) (l if err != nil { return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "pinging proxy") } - proxyAddr, err := proxyPing.tlsRoutingProxyPublicAddr() + proxyAddr, err := proxyPing.proxyWebAddr() if err != nil { - return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "getting proxy address") + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "determining proxy web addr") } s.log.DebugContext(ctx, "Issuing initial certificate for local proxy.") diff --git a/lib/tbot/service_database_tunnel.go b/lib/tbot/service_database_tunnel.go index 5082873559efc..c1fac6e38ccd8 100644 --- a/lib/tbot/service_database_tunnel.go +++ b/lib/tbot/service_database_tunnel.go @@ -93,9 +93,12 @@ func (s *DatabaseTunnelService) buildLocalProxyConfig(ctx context.Context) (lpCf if err != nil { return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "pinging proxy") } - proxyAddr, err := proxyPing.tlsRoutingProxyPublicAddr() + proxyAddr, err := proxyPing.proxyWebAddr() if err != nil { - return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "determining tls routing address") + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "determining proxy web address") + } + if !proxyPing.Proxy.TLSRoutingEnabled { + return alpnproxy.LocalProxyConfig{}, trace.BadParameter("proxy does not support TLS routing") } // Fetch information about the database and then issue the initial diff --git a/lib/tbot/service_identity_output.go b/lib/tbot/service_identity_output.go index 148688022f0ef..de4773973480c 100644 --- a/lib/tbot/service_identity_output.go +++ b/lib/tbot/service_identity_output.go @@ -252,13 +252,9 @@ func renderSSHConfig( ) defer span.End() - proxyAddr := proxyPing.Proxy.SSH.PublicAddr - if proxyPing.Proxy.TLSRoutingEnabled { - var err error - proxyAddr, err = proxyPing.tlsRoutingProxyPublicAddr() - if err != nil { - return trace.Wrap(err, "determining tls routing address") - } + proxyAddr, err := proxyPing.proxyWebAddr() + if err != nil { + return trace.Wrap(err, "determining proxy web addr") } proxyHost, proxyPort, err := utils.SplitHostPort(proxyAddr) diff --git a/lib/tbot/service_kubernetes_output.go b/lib/tbot/service_kubernetes_output.go index 02a126931bc47..a4b25b1cc98c2 100644 --- a/lib/tbot/service_kubernetes_output.go +++ b/lib/tbot/service_kubernetes_output.go @@ -426,7 +426,7 @@ func selectKubeConnectionMethod(proxyPong *proxyPingResponse) ( // Even if KubePublicAddr is specified, we still use the general // PublicAddr when using TLS routing. if proxyPong.Proxy.TLSRoutingEnabled { - addr, err := proxyPong.tlsRoutingProxyPublicAddr() + addr, err := proxyPong.proxyWebAddr() if err != nil { return "", "", trace.Wrap(err) } diff --git a/lib/tbot/service_ssh_multiplexer.go b/lib/tbot/service_ssh_multiplexer.go index ee17c8e41d2e3..c03f460b18ee0 100644 --- a/lib/tbot/service_ssh_multiplexer.go +++ b/lib/tbot/service_ssh_multiplexer.go @@ -273,13 +273,17 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) ( if err != nil { return nil, nil, "", nil, trace.Wrap(err) } - proxyAddr := proxyPing.Proxy.SSH.PublicAddr + proxyAddr, err := proxyPing.proxyWebAddr() + if err != nil { + return nil, nil, "", nil, trace.Wrap(err, "determining proxy web addr") + } + proxyHost, _, err = net.SplitHostPort(proxyAddr) + if err != nil { + return nil, nil, "", nil, trace.Wrap(err) + } + connUpgradeRequired := false if proxyPing.Proxy.TLSRoutingEnabled { - proxyAddr, err = proxyPing.tlsRoutingProxyPublicAddr() - if err != nil { - return nil, nil, "", nil, trace.Wrap(err, "determining proxy address") - } connUpgradeRequired, err = s.alpnUpgradeCache.isUpgradeRequired( ctx, proxyAddr, s.botCfg.Insecure, ) @@ -287,10 +291,6 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) ( return nil, nil, "", nil, trace.Wrap(err, "determining if ALPN upgrade is required") } } - proxyHost, _, err = net.SplitHostPort(proxyAddr) - if err != nil { - return nil, nil, "", nil, trace.Wrap(err) - } // Create Proxy and Auth clients proxyClient := newCyclingHostDialClient(100, proxyclient.ClientConfig{ diff --git a/lib/tbot/tbot.go b/lib/tbot/tbot.go index 4db192188e76d..2f5530a8a7647 100644 --- a/lib/tbot/tbot.go +++ b/lib/tbot/tbot.go @@ -764,22 +764,23 @@ type proxyPingResponse struct { // ProxyPing is incorrect. const shouldIgnoreProxyAddrEnv = "TBOT_IGNORE_PROXY_PING_ADDR" -// tlsRoutingProxyPublicAddr returns the public address of the proxy which -// should be used for TLS-routed connections. It takes into account the -// TBOT_IGNORE_PROXY_PING_ADDR env var which can be used to force the use of -// the proxy address explicitly provided by the user rather than that included -// in the ProxyPing. -func (p *proxyPingResponse) tlsRoutingProxyPublicAddr() (string, error) { - if os.Getenv(shouldIgnoreProxyAddrEnv) == "1" { +func shouldIgnoreProxyPingAddr() bool { + return os.Getenv(shouldIgnoreProxyAddrEnv) == "1" +} + +// proxyWebAddr returns the address to use to connect to the proxy web port. +// In TLS routing mode, this address should be used for most/all connections. +// This function takes into account the TBOT_IGNORE_PROXY_PING_ADDR environment +// variable, which can be used to force the use of the proxy address explicitly +// provided by the user rather than use the one fetched from the proxy ping. +func (p *proxyPingResponse) proxyWebAddr() (string, error) { + if shouldIgnoreProxyPingAddr() { if p.configuredProxyAddr == "" { return "", trace.BadParameter("TBOT_IGNORE_PROXY_PING_ADDR set but no explicit proxy address configured") } - if !p.Proxy.TLSRoutingEnabled { - return "", trace.BadParameter("TBOT_IGNORE_PROXY_PING_ADDR set but proxy does not have TLS routing enabled") - } return p.configuredProxyAddr, nil } - return p.Proxy.SSH.SSHPublicAddr, nil + return p.Proxy.SSH.PublicAddr, nil } type alpnProxyConnUpgradeRequiredCache struct {