diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index ba3bc7da..f976580d 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -485,6 +485,77 @@ func TestIntegrationHTTP101SwitchingProtocols(t *testing.T) { } } +func TestIntegrationHTTP304NotModified(t *testing.T) { + t.Parallel() + + tm := martiantest.NewModifier() + h := testHelper{ + Proxy: func(p *Proxy) { + p.ReadTimeout = 200 * time.Millisecond + p.WriteTimeout = 200 * time.Millisecond + p.RequestModifier = tm + p.ResponseModifier = tm + p.AllowHTTP = true + }, + } + + sl, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(): got %v, want no error", err) + } + + go func() { + conn, err := sl.Accept() + if err != nil { + log.Errorf(context.TODO(), "proxy_test: failed to accept connection: %v", err) + return + } + defer conn.Close() + + log.Infof(context.TODO(), "proxy_test: accepted connection: %s", conn.RemoteAddr()) + + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Errorf(context.TODO(), "proxy_test: failed to read request: %v", err) + return + } + + res := proxyutil.NewResponse(304, http.NoBody, req) + res.Header.Set("Content-Length", "13") + res.Write(conn) + log.Infof(context.TODO(), "proxy_test: sent 304 response") + + log.Infof(context.TODO(), "proxy_test: closed connection") + }() + + conn, cancel := h.proxyConn(t) + defer cancel() + defer conn.Close() + + host := sl.Addr().String() + + req, err := http.NewRequest(http.MethodGet, "http://"+host, http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): got %v, want no error", err) + } + if err := req.WriteProxy(conn); err != nil { + t.Fatalf("req.WriteProxy(): got %v, want no error", err) + } + + res, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatalf("http.ReadResponse(): got %v, want no error", err) + } + defer res.Body.Close() + + if got, want := res.StatusCode, 304; got != want { + t.Fatalf("res.StatusCode: got %d, want %d", got, want) + } + if _, err := conn.Read(make([]byte, 0)); err != nil { + t.Fatalf("conn.Read(): got %v, want no error", err) + } +} + func TestIntegrationUnexpectedUpstreamFailure(t *testing.T) { t.Parallel()