diff --git a/pkgs/httpcompress/httpCompress.go b/pkgs/httpcompress/httpCompress.go index 5a24b01..8ed3726 100644 --- a/pkgs/httpcompress/httpCompress.go +++ b/pkgs/httpcompress/httpCompress.go @@ -2,6 +2,7 @@ package httpcompress import ( "bufio" + "cmp" "errors" "io" "net" @@ -57,10 +58,12 @@ type Compressor struct { func NewCompressor(types ...string) *Compressor { // If types are provided, set those as the allowed types. If none are // provided, use the default list. - allowedTypes := lo.SliceToMap( - lo.If(len(types) > 0, types).Else(defaultCompressibleContentTypes), - func(t string) (string, any) { return t, nil }, - ) + if len(types) == 0 { + types = defaultCompressibleContentTypes + } + + // Build map based on types + allowedTypes := lo.SliceToMap(types, func(t string) (string, any) { return t, nil }) c := &Compressor{ pooledEncoders: map[string]*sync.Pool{}, @@ -73,6 +76,19 @@ func NewCompressor(types ...string) *Compressor { return c } +// Interface for types that allow resetting io.Writers. +type compressWriter interface { + io.Writer + Reset(w io.Writer) + Flush() error +} + +// An EncoderFunc is a function that wraps the provided io.Writer with a +// streaming compression algorithm and returns it. +// +// In case of failure, the function should return nil. +type EncoderFunc func(w io.Writer) compressWriter + // SetEncoder can be used to set the implementation of a compression algorithm. // // The encoding should be a standardised identifier. See: @@ -96,17 +112,54 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { return fn(io.Discard) }, } + c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) } +type compressResponseWriter struct { + http.ResponseWriter // The response writer to delegate to. + encoding string // The accepted encoding. + encoder compressWriter // The encoder to use. + cleanup func() // Cleanup function to reset and repool encoder. + compressor *Compressor // Holds the compressor configuration. + wroteHeader bool // Whether the header has been written. +} + +func (c *Compressor) findAcceptedEncoding(r *http.Request) string { + accepted := strings.Split(strings.ToLower(strings.ReplaceAll(r.Header.Get("Accept-Encoding"), " ", "")), ",") + for _, name := range c.encodingPrecedence { + if slices.Contains(accepted, name) { + // We found accepted encoding + if _, ok := c.pooledEncoders[name]; ok { + // And it also exists a pool for the encoder, we can use it + return name + } + } + } + return "" +} + +func (cw *compressResponseWriter) doCleanup() { + if cw.cleanup != nil { + cw.cleanup() + cw.cleanup = nil + } +} + // Handler returns a new middleware that will compress the response based on the // current Compressor. func (c *Compressor) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + encoding := c.findAcceptedEncoding(r) + if encoding == "" { + // No encoding accepted, serve directly + next.ServeHTTP(w, r) + return + } cw := &compressResponseWriter{ + encoding: encoding, compressor: c, ResponseWriter: w, - request: r, } next.ServeHTTP(cw, r) _ = cw.Close() @@ -114,107 +167,54 @@ func (c *Compressor) Handler(next http.Handler) http.Handler { }) } -// An EncoderFunc is a function that wraps the provided io.Writer with a -// streaming compression algorithm and returns it. -// -// In case of failure, the function should return nil. -type EncoderFunc func(w io.Writer) compressWriter - -// Interface for types that allow resetting io.Writers. -type compressWriter interface { - io.Writer - Reset(w io.Writer) - Flush() error -} - -type compressResponseWriter struct { - http.ResponseWriter // The response writer to delegate to. - encoder compressWriter // The encoder to use (if any). - cleanup func() // Cleanup function to reset and repool encoder. - compressor *Compressor // Holds the compressor configuration. - request *http.Request // The request that is being handled. - wroteHeader bool // Whether the header has been written. -} - func (cw *compressResponseWriter) isCompressable() bool { - // Parse the first part of the Content-Type response header. - contentType := cw.Header().Get("Content-Type") - if idx := strings.Index(contentType, ";"); idx >= 0 { - contentType = contentType[0:idx] - } - - // Is the content type compressable? - _, ok := cw.compressor.allowedTypes[contentType] + _, ok := cw.compressor.allowedTypes[strings.SplitN(cw.Header().Get("Content-Type"), ";", 2)[0]] return ok } -func (cw *compressResponseWriter) writer() io.Writer { - if cw.encoder != nil { - return cw.encoder - } - return cw.ResponseWriter -} - -// selectEncoder returns the encoder, the name of the encoder, and a closer function. -func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) { - // Parse the names of all accepted algorithms from the header. - accepted := strings.Split(strings.ToLower(strings.ReplaceAll(cw.request.Header.Get("Accept-Encoding"), " ", "")), ",") - - // Find supported encoder by accepted list by precedence - for _, name := range cw.compressor.encodingPrecedence { - if slices.Contains(accepted, name) { - if pool, ok := cw.compressor.pooledEncoders[name]; ok { - encoder := pool.Get().(compressWriter) - cleanup := func() { - encoder.Reset(nil) - pool.Put(encoder) - } - encoder.Reset(cw.ResponseWriter) - return encoder, name, cleanup - } - } +func (cw *compressResponseWriter) enableEncoder() { + pool := cw.compressor.pooledEncoders[cw.encoding] + cw.encoder = pool.Get().(compressWriter) + if cw.encoder == nil { + return } - - // No encoder found to match the accepted encoding - return nil, "", nil -} - -func (cw *compressResponseWriter) doCleanup() { - if cw.encoder != nil { + cw.cleanup = func() { + encoder := cw.encoder cw.encoder = nil - cw.cleanup() - cw.cleanup = nil + encoder.Reset(nil) + pool.Put(encoder) } + cw.encoder.Reset(cw.ResponseWriter) } func (cw *compressResponseWriter) WriteHeader(code int) { - defer cw.ResponseWriter.WriteHeader(code) - if cw.wroteHeader { return } + defer cw.ResponseWriter.WriteHeader(code) cw.wroteHeader = true - if cw.Header().Get("Content-Encoding") != "" { - // Data has already been compressed. + if cw.Header().Get("Content-Encoding") != "" || !cw.isCompressable() { + // Data has already been compressed or is not compressable. return } - if !cw.isCompressable() { - // Data is not compressable. + // Enable encoding + cw.enableEncoder() + if cw.encoder == nil { return } - var encoding string - cw.encoder, encoding, cw.cleanup = cw.selectEncoder() - if encoding != "" { - cw.Header().Set("Content-Encoding", encoding) - cw.Header().Add("Vary", "Accept-Encoding") + cw.Header().Set("Content-Encoding", cw.encoding) + cw.Header().Add("Vary", "Accept-Encoding") - // The content-length after compression is unknown - cw.Header().Del("Content-Length") - } + // The content-length after compression is unknown + cw.Header().Del("Content-Length") +} + +func (cw *compressResponseWriter) writer() io.Writer { + return cmp.Or[io.Writer](cw.encoder, cw.ResponseWriter) } func (cw *compressResponseWriter) Write(p []byte) (int, error) { diff --git a/pkgs/httpcompress/httpCompress_test.go b/pkgs/httpcompress/httpCompress_test.go new file mode 100644 index 0000000..4ec77fa --- /dev/null +++ b/pkgs/httpcompress/httpCompress_test.go @@ -0,0 +1,159 @@ +package httpcompress + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/klauspost/compress/gzip" + "github.com/klauspost/compress/zstd" +) + +func TestCompressor(t *testing.T) { + r := chi.NewRouter() + + compressor := NewCompressor("text/html", "text/css") + if len(compressor.pooledEncoders) != 2 { + t.Errorf("gzip and zstd should be pooled") + } + + r.Use(compressor.Handler) + + r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte("textstring")) + }) + + r.Get("/getjpeg", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "images/jpeg") + w.Write([]byte("textstring")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + tests := []struct { + name string + path string + expectedEncoding string + acceptedEncodings string + }{ + { + name: "no expected encodings due to no accepted encodings", + path: "/gethtml", + acceptedEncodings: "", + expectedEncoding: "", + }, + { + name: "no expected encodings due to content type", + path: "/getjpeg", + acceptedEncodings: "", + expectedEncoding: "", + }, + { + name: "gzip is only encoding", + path: "/gethtml", + acceptedEncodings: "gzip", + expectedEncoding: "gzip", + }, + { + name: "zstd is only encoding", + path: "/gethtml", + acceptedEncodings: "zstd", + expectedEncoding: "zstd", + }, + { + name: "deflate is only encoding", + path: "/gethtml", + acceptedEncodings: "deflate", + expectedEncoding: "", + }, + { + name: "multiple encoding seperated with comma and space", + path: "/gethtml", + acceptedEncodings: "zstd, gzip, deflate", + expectedEncoding: "zstd", + }, + { + name: "multiple encoding seperated with comma and without space", + path: "/gethtml", + acceptedEncodings: "zstd,gzip,deflate", + expectedEncoding: "zstd", + }, + { + name: "multiple encoding", + path: "/gethtml", + acceptedEncodings: "gzip, zstd", + expectedEncoding: "zstd", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings) + if respString != "textstring" { + t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString) + } + if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding { + t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got) + } + + }) + + } +} + +func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings string) (*http.Response, string) { + req, err := http.NewRequest(method, ts.URL+path, nil) + if err != nil { + t.Fatal(err) + return nil, "" + } + if encodings != "" { + req.Header.Set("Accept-Encoding", encodings) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return nil, "" + } + + respBody := decodeResponseBody(t, resp) + defer resp.Body.Close() + + return resp, respBody +} + +func decodeResponseBody(t *testing.T, resp *http.Response) string { + var reader io.Reader + switch resp.Header.Get("Content-Encoding") { + case "gzip": + var err error + reader, err = gzip.NewReader(resp.Body) + if err != nil { + t.Fatal(err) + } + case "zstd": + var err error + reader, err = zstd.NewReader(resp.Body) + if err != nil { + t.Fatal(err) + } + default: + reader = resp.Body + } + respBody, err := io.ReadAll(reader) + if err != nil { + t.Fatal(err) + return "" + } + if closer, ok := reader.(io.ReadCloser); ok { + closer.Close() + } + + return string(respBody) +}