Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --connect-to flag that allows to change dial target #905

Merged
merged 12 commits into from
Sep 12, 2024
11 changes: 11 additions & 0 deletions bind/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ func DialConfig(fs *pflag.FlagSet, cfg *forwarder.DialConfig, prefix string) {
"With or without a timeout, the operating system may impose its own earlier timeout. For instance, TCP timeouts are often around 3 minutes. ")
}

func ConnectTo(fs *pflag.FlagSet, cfg *[]forwarder.HostPortPair) {
fs.Var(anyflag.NewSliceValue[forwarder.HostPortPair](*cfg, cfg, forwarder.ParseHostPortPair),
"connect-to", "<HOST1:PORT1:HOST2:PORT2>,..."+
"For a request to the given HOST1:PORT1 pair, connect to HOST2:PORT2 instead. "+
"This option is suitable to direct requests at a specific server, e.g. at a specific cluster node in a cluster of servers. "+
"This option is only used to establish the network connection and does not work when request is routed using an upstream proxy. "+
"It does NOT affect the hostname/port that is used for TLS/SSL (e.g. SNI, certificate verification) or for the application protocols. "+
"HOST1 and PORT1 may be the empty string, meaning any host/port. "+
"HOST2 and PORT2 may also be the empty string, meaning use the request's original host/port. ")
}

func TLSClientConfig(fs *pflag.FlagSet, cfg *forwarder.TLSClientConfig) {
fs.DurationVar(&cfg.HandshakeTimeout,
"http-tls-handshake-timeout", cfg.HandshakeTimeout,
Expand Down
1 change: 1 addition & 0 deletions command/forwarder/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func FlagGroups() templates.FlagGroups {
Prefix: []string{
"http",
"cacert-file",
"connect-to",
"insecure",
},
},
Expand Down
6 changes: 6 additions & 0 deletions command/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type command struct {
promReg *prometheus.Registry
dnsConfig *forwarder.DNSConfig
httpTransportConfig *forwarder.HTTPTransportConfig
connectTo []forwarder.HostPortPair
pac *url.URL
credentials []*forwarder.HostPortUser
denyDomains []ruleset.RegexpListItem
Expand Down Expand Up @@ -137,6 +138,10 @@ func (c *command) runE(cmd *cobra.Command, _ []string) (cmdErr error) {
logger.Infof("using TLS key logging, writing to %s", c.httpTransportConfig.TLSClientConfig.KeyLogFile)
}

if len(c.connectTo) > 0 {
c.httpTransportConfig.RedirectFunc = forwarder.DialRedirectFromHostPortPairs(c.connectTo)
}

var pr forwarder.PACResolver
if c.pac != nil {
// Disable metrics for receiving PAC file.
Expand Down Expand Up @@ -408,6 +413,7 @@ func Command() *cobra.Command {
fs := cmd.Flags()
bind.DNSConfig(fs, c.dnsConfig)
bind.HTTPTransportConfig(fs, c.httpTransportConfig)
bind.ConnectTo(fs, &c.connectTo)
bind.PAC(fs, &c.pac)
bind.Credentials(fs, &c.credentials)
bind.DenyDomains(fs, &c.denyDomains)
Expand Down
33 changes: 0 additions & 33 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,39 +61,6 @@ func wildcardPortTo0(val string) string {
return strings.Join(s, ":")
}

// ParseHostPortUser parses a user:password@host:port string into HostUser.
func ParseHostPortUser(val string) (*HostPortUser, error) {
if val == "" || !strings.Contains(val, "@") {
return nil, errors.New("expected user[:password]@host:port")
}

idx := strings.LastIndex(val, "@")

up := val[:idx]
hp := val[idx:]

ui, err := ParseUserinfo(up)
if err != nil {
return nil, err
}

u, err := url.Parse("http://" + wildcardPortTo0(hp))
if err != nil {
return nil, err
}

hpi := &HostPortUser{
Host: u.Hostname(),
Port: u.Port(),
Userinfo: ui,
}
if err := hpi.Validate(); err != nil {
return nil, err
}

return hpi, nil
}

func ParseProxyURL(val string) (*url.URL, error) {
scheme, hpu, ok := strings.Cut(val, "://")
if !ok {
Expand Down
56 changes: 0 additions & 56 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,6 @@ import (
"testing"
)

func TestParseHostPortUser(t *testing.T) {
tests := []struct {
name string
input string
err string
}{
{
name: "normal",
input: "user:pass@foo:80",
},
{
name: "no user",
input: ":pass@foo:80",
err: "username cannot be empty",
},
{
name: "empty",
input: "",
err: "expected user[:password]@host:port",
},
{
name: "colon in password",
input: "user:pass:pass@foo:80",
},
{
name: "@ in password",
input: "user:p@ss@foo:80",
},
{
name: "@ in username",
input: "user@:pass@foo:80",
},
}

for i := range tests {
tc := &tests[i]
t.Run(tc.name, func(t *testing.T) {
hpi, err := ParseHostPortUser(tc.input)
if tc.err == "" {
if err != nil {
t.Fatalf("expected success, got %q", err)
}
pass, ok := hpi.Password()
if ok {
pass = ":" + pass
}
if hpi.Username()+pass+"@"+hpi.Host+":"+hpi.Port != tc.input {
t.Errorf("expected %q, got %q", tc.input, hpi.String())
}
} else if !strings.Contains(err.Error(), tc.err) {
t.Fatalf("expected error to contain %q, got %q", tc.err, err)
}
})
}
}

func TestParseUserinfo(t *testing.T) {
tests := []struct {
name string
Expand Down
54 changes: 0 additions & 54 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,6 @@ import (
"github.com/saucelabs/forwarder/log"
)

type HostPortUser struct {
Host string
Port string
*url.Userinfo
}

func (hpu *HostPortUser) Validate() error {
if hpu.Host == "" {
return errors.New("missing host")
}
if hpu.Port == "" {
return errors.New("missing port")
}
if hpu.Userinfo == nil {
return errors.New("missing user")
}
return validatedUserInfo(hpu.Userinfo)
}

func (hpu *HostPortUser) String() string {
if hpu == nil {
return ""
}

port := hpu.Port
if port == "0" {
port = "*"
}

p, ok := hpu.Password()
if !ok {
return fmt.Sprintf("%s@%s:%s", hpu.Username(), hpu.Host, port)
}

return fmt.Sprintf("%s:%s@%s:%s", hpu.Username(), p, hpu.Host, port)
}

func RedactHostPortUser(hpu *HostPortUser) string {
if hpu == nil {
return ""
}

port := hpu.Port
if port == "0" {
port = "*"
}

if _, ok := hpu.Password(); !ok {
return fmt.Sprintf("%s@%s:%s", hpu.Username(), hpu.Host, port)
}

return fmt.Sprintf("%s:xxxxx@%s:%s", hpu.Username(), hpu.Host, port)
}

type CredentialsMatcher struct {
hostport map[string]*url.Userinfo
host map[string]*url.Userinfo
Expand Down
5 changes: 5 additions & 0 deletions e2e/forwarder/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ func (s *Service) WithDNSTimeout(timeout time.Duration) *Service {
return s
}

func (s *Service) WithConnectTo(v string) compose.ServiceBuilder {
s.Environment["FORWARDER_CONNECT_TO"] = v
return s
}

func (s *Service) WithHTTPDialTimeout(timeout time.Duration) *Service {
s.Environment["FORWARDER_HTTP_DIAL_TIMEOUT"] = timeout.String()
return s
Expand Down
7 changes: 5 additions & 2 deletions e2e/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,16 @@ func main() {
if *args.debug {
if strings.HasPrefix(srv.Image, "saucelabs/forwarder") {
srv.Environment["FORWARDER_LOG_LEVEL"] = "debug"
srv.Environment["FORWARDER_LOG_HTTP"] = "headers,api:errors"
}
if srv.Name == forwarder.ProxyServiceName {
switch srv.Name {
case forwarder.ProxyServiceName:
srv.Environment["FORWARDER_LOG_HTTP"] = "headers,api:errors"
srv.Ports = append(srv.Ports,
"3128:3128",
"10000:10000",
)
case forwarder.HttpbinServiceName:
srv.Environment["FORWARDER_LOG_HTTP"] = "headers"
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions e2e/setups.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func AllSetups() []setup.Setup {
SetupFlagResponseHeader(l)
SetupFlagConnectHeader(l)
SetupFlagDNSServer(l)
SetupFlagConnectTo(l)
SetupFlagInsecure(l)
SetupFlagMITMCACert(l)
SetupFlagMITMGenCA(l)
Expand Down Expand Up @@ -294,6 +295,22 @@ func SetupFlagDNSServer(l *setupList) {
}
}

func SetupFlagConnectTo(l *setupList) {
l.Add(
setup.Setup{
Name: "flag-connect-to",
Compose: compose.NewBuilder().
AddService(
forwarder.HttpbinService()).
AddService(
forwarder.ProxyService().
WithConnectTo("foo::httpbin:8080")).
MustBuild(),
Run: "^TestFlagConnectTo$",
},
)
}

func SetupFlagInsecure(l *setupList) {
l.Add(
setup.Setup{
Expand Down
4 changes: 4 additions & 0 deletions e2e/tests/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ func TestFlagDNSServer(t *testing.T) {
})
}

func TestFlagConnectTo(t *testing.T) {
newClient(t, serviceScheme("HTTPBIN_PROTOCOL")+"://foo:123").GET("/status/200").ExpectStatus(http.StatusOK)
}

func TestFlagInsecure(t *testing.T) {
t.Run("true", func(t *testing.T) {
newClient(t, httpbin).GET("/status/200").ExpectStatus(http.StatusOK)
Expand Down
Loading