Skip to content

Commit

Permalink
martian: add support for TLSHandshakeTimeout
Browse files Browse the repository at this point in the history
Martian Proxy will explicitly do the handshake if accepted connection is tls.Conn.
  • Loading branch information
Choraden committed Oct 15, 2024
1 parent d5e6021 commit d68795c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func (hp *HTTPProxy) configureProxy() error {
hp.proxy.WithoutWarning = true
hp.proxy.ErrorResponse = hp.errorResponse
hp.proxy.IdleTimeout = hp.config.IdleTimeout
hp.proxy.TLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout
hp.proxy.ReadTimeout = hp.config.ReadTimeout
hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout
hp.proxy.WriteTimeout = hp.config.WriteTimeout
Expand Down
11 changes: 10 additions & 1 deletion internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type Proxy struct {
// If both are zero, there is no timeout.
IdleTimeout time.Duration

// TLSHandshakeTimeout is the maximum amount of time to wait for a TLS handshake.
// The proxy will try to cast accepted connections to tls.Conn and perform a handshake.
// If TLSHandshakeTimeout is zero, no timeout is set.
TLSHandshakeTimeout time.Duration

// ReadTimeout is the maximum duration for reading the entire
// request, including the body. A zero or negative value means
// there will be no timeout.
Expand Down Expand Up @@ -255,7 +260,11 @@ func (p *Proxy) handleLoop(conn net.Conn) {
return
}

pc := newProxyConn(p, conn)
pc, err := newProxyConn(p, conn)
if err != nil {
log.Errorf(context.TODO(), "failed to create proxy connection: %v", err)
return
}

const maxConsecutiveErrors = 5
errorsN := 0
Expand Down
13 changes: 11 additions & 2 deletions internal/martian/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,28 @@ type proxyConn struct {
cs tls.ConnectionState
}

func newProxyConn(p *Proxy, conn net.Conn) *proxyConn {
func newProxyConn(p *Proxy, conn net.Conn) (*proxyConn, error) {
v := &proxyConn{
Proxy: p,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
conn: conn,
}

if tconn, ok := conn.(*tls.Conn); ok {
ctx := context.Background()
if p.TLSHandshakeTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), p.TLSHandshakeTimeout)
defer cancel()
}
if err := tconn.HandshakeContext(ctx); err != nil {
return nil, fmt.Errorf("failed to do TLS handshake: %w", err)
}
v.secure = true
v.cs = tconn.ConnectionState()
}

return v
return v, nil
}

func (p *proxyConn) readRequest() (*http.Request, error) {
Expand Down
31 changes: 31 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,37 @@ func TestIdleTimeout(t *testing.T) {
}
}

func TestTLSHandshakeTimeout(t *testing.T) {
t.Parallel()

l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(): got %v, want no error", err)
}
_, mc := certs(t)
l = tls.NewListener(l, mc.TLS(context.Background()))

h := testHelper{
Listener: l,
Proxy: func(p *Proxy) {
p.TLSHandshakeTimeout = 100 * time.Millisecond
},
}

c, cancel := h.proxyClient(t)
defer cancel()

conn, err := net.Dial("tcp", c.Addr)
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}

time.Sleep(200 * time.Millisecond)
if _, err := conn.Read(make([]byte, 1)); !errors.Is(err, io.EOF) {
t.Fatalf("conn.Read(): got %v, want io.EOF", err)
}
}

func TestReadHeaderTimeout(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit d68795c

Please sign in to comment.