From 45e03e2246fd727abf041fecfabec4659fda4e36 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Sat, 22 Jun 2024 16:17:29 +0200 Subject: [PATCH] Revert "Restructure and test http compress package" This reverts commit ebad6ef45881664a684be0f67915f0907611162b. --- pkgs/httpcompress/httpCompress.go | 166 ++++++++++++------------- pkgs/httpcompress/httpCompress_test.go | 159 ----------------------- 2 files changed, 83 insertions(+), 242 deletions(-) delete mode 100644 pkgs/httpcompress/httpCompress_test.go diff --git a/pkgs/httpcompress/httpCompress.go b/pkgs/httpcompress/httpCompress.go index 8ed3726..5a24b01 100644 --- a/pkgs/httpcompress/httpCompress.go +++ b/pkgs/httpcompress/httpCompress.go @@ -2,7 +2,6 @@ package httpcompress import ( "bufio" - "cmp" "errors" "io" "net" @@ -58,12 +57,10 @@ type Compressor struct { func NewCompressor(types ...string) *Compressor { // If types are provided, set those as the allowed types. If none are // provided, use the default list. - if len(types) == 0 { - types = defaultCompressibleContentTypes - } - - // Build map based on types - allowedTypes := lo.SliceToMap(types, func(t string) (string, any) { return t, nil }) + allowedTypes := lo.SliceToMap( + lo.If(len(types) > 0, types).Else(defaultCompressibleContentTypes), + func(t string) (string, any) { return t, nil }, + ) c := &Compressor{ pooledEncoders: map[string]*sync.Pool{}, @@ -76,19 +73,6 @@ func NewCompressor(types ...string) *Compressor { return c } -// Interface for types that allow resetting io.Writers. -type compressWriter interface { - io.Writer - Reset(w io.Writer) - Flush() error -} - -// An EncoderFunc is a function that wraps the provided io.Writer with a -// streaming compression algorithm and returns it. -// -// In case of failure, the function should return nil. -type EncoderFunc func(w io.Writer) compressWriter - // SetEncoder can be used to set the implementation of a compression algorithm. // // The encoding should be a standardised identifier. See: @@ -112,54 +96,17 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { return fn(io.Discard) }, } - c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) } -type compressResponseWriter struct { - http.ResponseWriter // The response writer to delegate to. - encoding string // The accepted encoding. - encoder compressWriter // The encoder to use. - cleanup func() // Cleanup function to reset and repool encoder. - compressor *Compressor // Holds the compressor configuration. - wroteHeader bool // Whether the header has been written. -} - -func (c *Compressor) findAcceptedEncoding(r *http.Request) string { - accepted := strings.Split(strings.ToLower(strings.ReplaceAll(r.Header.Get("Accept-Encoding"), " ", "")), ",") - for _, name := range c.encodingPrecedence { - if slices.Contains(accepted, name) { - // We found accepted encoding - if _, ok := c.pooledEncoders[name]; ok { - // And it also exists a pool for the encoder, we can use it - return name - } - } - } - return "" -} - -func (cw *compressResponseWriter) doCleanup() { - if cw.cleanup != nil { - cw.cleanup() - cw.cleanup = nil - } -} - // Handler returns a new middleware that will compress the response based on the // current Compressor. func (c *Compressor) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - encoding := c.findAcceptedEncoding(r) - if encoding == "" { - // No encoding accepted, serve directly - next.ServeHTTP(w, r) - return - } cw := &compressResponseWriter{ - encoding: encoding, compressor: c, ResponseWriter: w, + request: r, } next.ServeHTTP(cw, r) _ = cw.Close() @@ -167,54 +114,107 @@ func (c *Compressor) Handler(next http.Handler) http.Handler { }) } +// An EncoderFunc is a function that wraps the provided io.Writer with a +// streaming compression algorithm and returns it. +// +// In case of failure, the function should return nil. +type EncoderFunc func(w io.Writer) compressWriter + +// Interface for types that allow resetting io.Writers. +type compressWriter interface { + io.Writer + Reset(w io.Writer) + Flush() error +} + +type compressResponseWriter struct { + http.ResponseWriter // The response writer to delegate to. + encoder compressWriter // The encoder to use (if any). + cleanup func() // Cleanup function to reset and repool encoder. + compressor *Compressor // Holds the compressor configuration. + request *http.Request // The request that is being handled. + wroteHeader bool // Whether the header has been written. +} + func (cw *compressResponseWriter) isCompressable() bool { - _, ok := cw.compressor.allowedTypes[strings.SplitN(cw.Header().Get("Content-Type"), ";", 2)[0]] + // Parse the first part of the Content-Type response header. + contentType := cw.Header().Get("Content-Type") + if idx := strings.Index(contentType, ";"); idx >= 0 { + contentType = contentType[0:idx] + } + + // Is the content type compressable? + _, ok := cw.compressor.allowedTypes[contentType] return ok } -func (cw *compressResponseWriter) enableEncoder() { - pool := cw.compressor.pooledEncoders[cw.encoding] - cw.encoder = pool.Get().(compressWriter) - if cw.encoder == nil { - return +func (cw *compressResponseWriter) writer() io.Writer { + if cw.encoder != nil { + return cw.encoder } - cw.cleanup = func() { - encoder := cw.encoder + return cw.ResponseWriter +} + +// selectEncoder returns the encoder, the name of the encoder, and a closer function. +func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) { + // Parse the names of all accepted algorithms from the header. + accepted := strings.Split(strings.ToLower(strings.ReplaceAll(cw.request.Header.Get("Accept-Encoding"), " ", "")), ",") + + // Find supported encoder by accepted list by precedence + for _, name := range cw.compressor.encodingPrecedence { + if slices.Contains(accepted, name) { + if pool, ok := cw.compressor.pooledEncoders[name]; ok { + encoder := pool.Get().(compressWriter) + cleanup := func() { + encoder.Reset(nil) + pool.Put(encoder) + } + encoder.Reset(cw.ResponseWriter) + return encoder, name, cleanup + } + } + } + + // No encoder found to match the accepted encoding + return nil, "", nil +} + +func (cw *compressResponseWriter) doCleanup() { + if cw.encoder != nil { cw.encoder = nil - encoder.Reset(nil) - pool.Put(encoder) + cw.cleanup() + cw.cleanup = nil } - cw.encoder.Reset(cw.ResponseWriter) } func (cw *compressResponseWriter) WriteHeader(code int) { + defer cw.ResponseWriter.WriteHeader(code) + if cw.wroteHeader { return } - defer cw.ResponseWriter.WriteHeader(code) cw.wroteHeader = true - if cw.Header().Get("Content-Encoding") != "" || !cw.isCompressable() { - // Data has already been compressed or is not compressable. + if cw.Header().Get("Content-Encoding") != "" { + // Data has already been compressed. return } - // Enable encoding - cw.enableEncoder() - if cw.encoder == nil { + if !cw.isCompressable() { + // Data is not compressable. return } - cw.Header().Set("Content-Encoding", cw.encoding) - cw.Header().Add("Vary", "Accept-Encoding") - - // The content-length after compression is unknown - cw.Header().Del("Content-Length") -} + var encoding string + cw.encoder, encoding, cw.cleanup = cw.selectEncoder() + if encoding != "" { + cw.Header().Set("Content-Encoding", encoding) + cw.Header().Add("Vary", "Accept-Encoding") -func (cw *compressResponseWriter) writer() io.Writer { - return cmp.Or[io.Writer](cw.encoder, cw.ResponseWriter) + // The content-length after compression is unknown + cw.Header().Del("Content-Length") + } } func (cw *compressResponseWriter) Write(p []byte) (int, error) { diff --git a/pkgs/httpcompress/httpCompress_test.go b/pkgs/httpcompress/httpCompress_test.go deleted file mode 100644 index 4ec77fa..0000000 --- a/pkgs/httpcompress/httpCompress_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package httpcompress - -import ( - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/go-chi/chi/v5" - "github.com/klauspost/compress/gzip" - "github.com/klauspost/compress/zstd" -) - -func TestCompressor(t *testing.T) { - r := chi.NewRouter() - - compressor := NewCompressor("text/html", "text/css") - if len(compressor.pooledEncoders) != 2 { - t.Errorf("gzip and zstd should be pooled") - } - - r.Use(compressor.Handler) - - r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.Write([]byte("textstring")) - }) - - r.Get("/getjpeg", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "images/jpeg") - w.Write([]byte("textstring")) - }) - - ts := httptest.NewServer(r) - defer ts.Close() - - tests := []struct { - name string - path string - expectedEncoding string - acceptedEncodings string - }{ - { - name: "no expected encodings due to no accepted encodings", - path: "/gethtml", - acceptedEncodings: "", - expectedEncoding: "", - }, - { - name: "no expected encodings due to content type", - path: "/getjpeg", - acceptedEncodings: "", - expectedEncoding: "", - }, - { - name: "gzip is only encoding", - path: "/gethtml", - acceptedEncodings: "gzip", - expectedEncoding: "gzip", - }, - { - name: "zstd is only encoding", - path: "/gethtml", - acceptedEncodings: "zstd", - expectedEncoding: "zstd", - }, - { - name: "deflate is only encoding", - path: "/gethtml", - acceptedEncodings: "deflate", - expectedEncoding: "", - }, - { - name: "multiple encoding seperated with comma and space", - path: "/gethtml", - acceptedEncodings: "zstd, gzip, deflate", - expectedEncoding: "zstd", - }, - { - name: "multiple encoding seperated with comma and without space", - path: "/gethtml", - acceptedEncodings: "zstd,gzip,deflate", - expectedEncoding: "zstd", - }, - { - name: "multiple encoding", - path: "/gethtml", - acceptedEncodings: "gzip, zstd", - expectedEncoding: "zstd", - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings) - if respString != "textstring" { - t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString) - } - if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding { - t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got) - } - - }) - - } -} - -func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings string) (*http.Response, string) { - req, err := http.NewRequest(method, ts.URL+path, nil) - if err != nil { - t.Fatal(err) - return nil, "" - } - if encodings != "" { - req.Header.Set("Accept-Encoding", encodings) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - return nil, "" - } - - respBody := decodeResponseBody(t, resp) - defer resp.Body.Close() - - return resp, respBody -} - -func decodeResponseBody(t *testing.T, resp *http.Response) string { - var reader io.Reader - switch resp.Header.Get("Content-Encoding") { - case "gzip": - var err error - reader, err = gzip.NewReader(resp.Body) - if err != nil { - t.Fatal(err) - } - case "zstd": - var err error - reader, err = zstd.NewReader(resp.Body) - if err != nil { - t.Fatal(err) - } - default: - reader = resp.Body - } - respBody, err := io.ReadAll(reader) - if err != nil { - t.Fatal(err) - return "" - } - if closer, ok := reader.(io.ReadCloser); ok { - closer.Close() - } - - return string(respBody) -}