Skip to content

Commit

Permalink
http proxy: allow multiple listeners
Browse files Browse the repository at this point in the history
Add HTTPProxyConfig.ExtraListeners and refactor code to work with multiple listeners.
  • Loading branch information
mmatczuk committed Nov 8, 2024
1 parent 2741f5f commit cd53cb9
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 57 deletions.
159 changes: 103 additions & 56 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"net/http"
"net/url"
"slices"
"sync"
"time"

"github.com/saucelabs/forwarder/hostsfile"
Expand All @@ -28,6 +27,8 @@ import (
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/middleware"
"github.com/saucelabs/forwarder/pac"
"go.uber.org/multierr"
"golang.org/x/sync/errgroup"
)

type ProxyLocalhostMode string
Expand Down Expand Up @@ -80,6 +81,7 @@ var ErrConnectFallback = martian.ErrConnectFallback

type HTTPProxyConfig struct {
HTTPServerConfig
ExtraListeners []NamedListenerConfig
Name string
MITM *MITMConfig
MITMDomains Matcher
Expand Down Expand Up @@ -122,6 +124,11 @@ func (c *HTTPProxyConfig) Validate() error {
if err := c.HTTPServerConfig.Validate(); err != nil {
return err
}
for _, lc := range c.ExtraListeners {
if lc.Name == "" {
return errors.New("extra listener name is required")
}
}
if c.Protocol != HTTPScheme && c.Protocol != HTTPSScheme {
return fmt.Errorf("unsupported protocol: %s", c.Protocol)
}
Expand All @@ -148,7 +155,7 @@ type HTTPProxy struct {
localhost []string

tlsConfig *tls.Config
listener net.Listener
listeners []net.Listener
}

// NewHTTPProxy creates a new HTTP proxy.
Expand All @@ -171,13 +178,15 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher,
}
hp.localhost = append(hp.localhost, lh...)

l, err := hp.listen()
ll, err := hp.listen()
if err != nil {
return nil, err
}
hp.listener = l
hp.listeners = ll

hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol)
for _, l := range hp.listeners {
hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol)
}

return hp, nil
}
Expand Down Expand Up @@ -500,79 +509,117 @@ func (hp *HTTPProxy) handler() http.Handler {
}

func (hp *HTTPProxy) Run(ctx context.Context) error {
var srv *http.Server

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()

<-ctx.Done()
if srv != nil {
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
}
} else {
hp.Close()
}
}()

var srvErr error
if hp.config.TestingHTTPHandler {
hp.log.Infof("using http handler")
srv = &http.Server{
Handler: hp.handler(),
IdleTimeout: hp.config.IdleTimeout,
ReadTimeout: hp.config.ReadTimeout,
ReadHeaderTimeout: hp.config.ReadHeaderTimeout,
WriteTimeout: hp.config.WriteTimeout,
}
srvErr = srv.Serve(hp.listener)
} else {
srvErr = hp.proxy.Serve(hp.listener)
return hp.runHTTPHandler(ctx)
}
return hp.run(ctx)
}

func (hp *HTTPProxy) runHTTPHandler(ctx context.Context) error {
srv := http.Server{
Handler: hp.handler(),
IdleTimeout: hp.config.IdleTimeout,
ReadTimeout: hp.config.ReadTimeout,
ReadHeaderTimeout: hp.config.ReadHeaderTimeout,
WriteTimeout: hp.config.WriteTimeout,
}
if srvErr != nil {
if errors.Is(srvErr, net.ErrClosed) {
srvErr = nil

var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
}
return srvErr
return ctx.Err()
})
for i := range hp.listeners {
l := hp.listeners[i]
g.Go(func() error {
err := srv.Serve(l)
if errors.Is(err, http.ErrServerClosed) {
err = nil
}
return err
})
}
return g.Wait()
}

wg.Wait()
return nil
func (hp *HTTPProxy) run(ctx context.Context) error {
var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
hp.Close()
return ctx.Err()
})
for i := range hp.listeners {
l := hp.listeners[i]
g.Go(func() error {
err := hp.proxy.Serve(l)
if errors.Is(err, net.ErrClosed) {
err = nil
}
return err
})
}
return g.Wait()
}

func (hp *HTTPProxy) listen() (net.Listener, error) {
func (hp *HTTPProxy) listen() ([]net.Listener, error) {
switch hp.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
default:
return nil, fmt.Errorf("invalid protocol %q", hp.config.Protocol)
}

l := Listener{
ListenerConfig: hp.config.ListenerConfig,
TLSConfig: hp.tlsConfig,
PromConfig: PromConfig{
PromNamespace: hp.config.PromNamespace,
PromRegistry: hp.config.PromRegistry,
},
}

if err := l.Listen(); err != nil {
return nil, err
if len(hp.config.ExtraListeners) == 0 {
l := &Listener{
ListenerConfig: hp.config.ListenerConfig,
TLSConfig: hp.tlsConfig,
PromConfig: PromConfig{
PromNamespace: hp.config.PromNamespace,
PromRegistry: hp.config.PromRegistry,
},
}
if err := l.Listen(); err != nil {
return nil, err
}
return []net.Listener{l}, nil
}

return &l, nil
return MultiListener{
ListenerConfigs: append([]NamedListenerConfig{{ListenerConfig: hp.config.ListenerConfig}}, hp.config.ExtraListeners...),
TLSConfig: func(lc NamedListenerConfig) *tls.Config {
return hp.tlsConfig
},
PromConfig: hp.config.PromConfig,
}.Listen()
}

// Addr returns the address the server is listening on.
func (hp *HTTPProxy) Addr() string {
return hp.listener.Addr().String()
func (hp *HTTPProxy) Addr() (addrs []string, ok bool) {
addrs = make([]string, len(hp.listeners))
ok = true
for i, l := range hp.listeners {
addrs[i] = l.Addr().String()
if addrs[i] == "" {
ok = false
}
}
return
}

func (hp *HTTPProxy) Close() error {
err := hp.listener.Close()
// Close listeners first to prevent new connections.
var err error
for _, l := range hp.listeners {
if e := l.Close(); e != nil {
err = multierr.Append(err, e)
}
}

// Close the proxy to stop serving requests.
hp.proxy.Close()

if tr, ok := hp.transport.(*http.Transport); ok {
Expand Down
47 changes: 46 additions & 1 deletion net.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,48 @@ func DefaultListenerConfig(addr string) *ListenerConfig {
}
}

type NamedListenerConfig struct {
Name string
ListenerConfig
}

// MultiListener is a builder for multiple listeners sharing the same prometheus configuration.
type MultiListener struct {
ListenerConfigs []NamedListenerConfig
TLSConfig func(NamedListenerConfig) *tls.Config
PromConfig
}

func (ml MultiListener) Listen() (_ []net.Listener, ferr error) {
prototype := Listener{
metrics: newListenerMetrics(ml.PromRegistry, ml.PromNamespace),
}

listeners := make([]net.Listener, 0, len(ml.ListenerConfigs))
defer func() {
if ferr != nil {
for _, l := range listeners {
l.Close()
}
}
}()

for _, lc := range ml.ListenerConfigs {
l := new(Listener)
*l = prototype
l.ListenerConfig = lc.ListenerConfig
if ml.TLSConfig != nil {
l.TLSConfig = ml.TLSConfig(lc)
}
if err := l.Listen(); err != nil {
return nil, err
}
listeners = append(listeners, l)
}

return listeners, nil
}

type Listener struct {
ListenerConfig
TLSConfig *tls.Config
Expand Down Expand Up @@ -222,7 +264,10 @@ func (l *Listener) Listen() error {
}

l.listener = ll
l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace)

if l.metrics == nil {
l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace)
}

return nil
}
Expand Down

0 comments on commit cd53cb9

Please sign in to comment.