Skip to content

Commit

Permalink
martian: add content length check in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatczuk committed Jul 25, 2024
1 parent c88a490 commit 50184fd
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,16 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) {
utr.Respond(299)
upstream.RoundTripper = utr

utm := martiantest.NewModifier()

// Force the CONNECT request to dial the local TLS server.
utm.RequestFunc(func(req *http.Request) {
if req.Method == http.MethodConnect && req.ContentLength != -1 {
t.Errorf("req.ContentLength: got %d, want -1", req.ContentLength)
}
})
upstream.RequestModifier = utm

ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
if err != nil {
t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
Expand Down Expand Up @@ -1017,6 +1027,9 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) {
if got, want := res.StatusCode, 200; got != want {
t.Fatalf("res.StatusCode: got %d, want %d", got, want)
}
if res.ContentLength != -1 {
t.Errorf("res.ContentLength: got %d, want -1", res.ContentLength)
}

roots := x509.NewCertPool()
roots.AddCert(ca)
Expand Down Expand Up @@ -1076,6 +1089,10 @@ func TestIntegrationConnectFunc(t *testing.T) {
l := newListener(t)
p := new(Proxy)
p.ConnectFunc = func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) {
if req.ContentLength != -1 {
t.Errorf("req.ContentLength: got %d, want -1", req.ContentLength)
}

pr, pw := io.Pipe()
return newConnectResponse(req), pipeConn{pr, pw}, nil
}
Expand Down Expand Up @@ -1110,6 +1127,9 @@ func TestIntegrationConnectFunc(t *testing.T) {
if got, want := res.StatusCode, 200; got != want {
t.Errorf("res.StatusCode: got %d, want %d", got, want)
}
if res.ContentLength != -1 {
t.Errorf("res.ContentLength: got %d, want -1", res.ContentLength)
}

if _, err := conn.Write([]byte("12345")); err != nil {
t.Fatalf("conn.Write(): got %v, want no error", err)
Expand Down

0 comments on commit 50184fd

Please sign in to comment.