diff --git a/caddy/streaming.go b/caddy/streaming.go new file mode 100644 index 000000000..af60a589b --- /dev/null +++ b/caddy/streaming.go @@ -0,0 +1,169 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Most of the code in this file was initially borrowed from the Go +// standard library and modified; It had this copyright notice: +// Copyright 2011 The Go Authors + +// This code was largely taken from Caddy (https://github.com/caddyserver/caddy) +// at git hash e385be922569c07a0471a6798d4aeaf972facb5b. +// The above copyright notice is what it originally had. + +package caddy + +import ( + "io" + "net/http" + "sync" + "time" +) + +type FlushResponseWriter interface { + http.ResponseWriter + http.Flusher +} + +func WrapFlushResponseWriter(rw http.ResponseWriter, flushInterval time.Duration) http.ResponseWriter { + mlw := maxLatencyWriter{ + dst: rw.(writeFlusher), + latency: flushInterval, + } + + return &maxLatencyResponseWriter{rw, mlw} +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +// Implement io.ReaderFrom, this will allow goproxy to io.Copy to it +// with the flushing mechanics in place. +func (m *maxLatencyWriter) ReadFrom(src io.Reader) (n int64, err error) { + defer m.stop() + + buf := streamingBufPool.Get().([]byte) + defer streamingBufPool.Put(buf) + + return copyBuffer(m, src, buf) +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +// Passes normal response writer calls to the original. +// But handles writes / copies using maxLatencyWriter. +type maxLatencyResponseWriter struct { + rw http.ResponseWriter + mlw maxLatencyWriter +} + +func (f *maxLatencyResponseWriter) Header() http.Header { + return f.rw.Header() +} + +func (f *maxLatencyResponseWriter) WriteHeader(status int) { + f.rw.WriteHeader(status) +} + +func (f *maxLatencyResponseWriter) Write(buf []byte) (int, error) { + return f.mlw.Write(buf) +} + +// Implement io.ReaderFrom, this will allow goproxy to io.Copy to it +// with the flushing mechanics in place. +func (f *maxLatencyResponseWriter) ReadFrom(src io.Reader) (int64, error) { + return f.mlw.ReadFrom(src) +} + +var streamingBufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024) + }, +} diff --git a/config.go b/config.go index 080af6ad3..01f672c9c 100644 --- a/config.go +++ b/config.go @@ -71,6 +71,7 @@ func newDefaultConfig() *Config { UpstreamTLSHandshakeTimeout: 10 * time.Second, UpstreamTimeout: 10 * time.Second, UseLetsEncrypt: false, + FlushInterval: 0, } } diff --git a/config_sample.yml b/config_sample.yml index 107ebe2b8..cef93693f 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -90,3 +90,6 @@ cors-methods: [] cors-credentials: true|false # the max age (Access-Control-Max-Age) cors-max-age: 1h + +# duration interval to flush reverse proxy body chunks +flush-interval: 0s diff --git a/doc.go b/doc.go index c42c53f0e..123cbed28 100644 --- a/doc.go +++ b/doc.go @@ -364,6 +364,9 @@ type Config struct { // DisableAllLogging indicates no logging at all DisableAllLogging bool `json:"disable-all-logging" yaml:"disable-all-logging" usage:"disables all logging to stdout and stderr"` + + // FlushInterval is the maximum latency before we should flush proxied response bodies + FlushInterval time.Duration `json:"flush-interval" yaml:"flush-interval" usage:"the rate at which to flush response body, negative value is immedate, 0 value leaves default behaviour"` } // getVersion returns the proxy version diff --git a/middleware.go b/middleware.go index 3d3bf5aca..c88665f98 100644 --- a/middleware.go +++ b/middleware.go @@ -27,6 +27,8 @@ import ( "sync" "time" + "github.com/louketo/louketo-proxy/caddy" + uuid "github.com/gofrs/uuid" "github.com/PuerkitoBio/purell" @@ -63,6 +65,22 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } +// In order to send chunked responses, we must implement http.Flusher +func (w *gzipResponseWriter) Flush() { + gz := w.Writer.(*gzip.Writer) + f, ok := w.ResponseWriter.(http.Flusher) + + if !ok { + panic(fmt.Sprintf("Oh dear, we flushed a ResponseWriter that can't flush: %T", w.ResponseWriter)) + } + + err := gz.Flush() + if err != nil { + panic(err) + } + f.Flush() +} + // gzipMiddleware is responsible for compressing a response func gzipMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -540,6 +558,27 @@ func (r *oauthProxy) securityMiddleware(next http.Handler) http.Handler { }) } +// flushIntervalMiddleware adds a Caddy-style control over `Transfer-Encoding: chunked` behaviour +func (r *oauthProxy) flushIntervalMiddleware(flushInterval time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Setting flushInterval to 0, means to keep unchanged flushing behaviour. + if flushInterval != 0 { + // We require the destination to implement writer and flusher. + if _, ok := w.(http.Flusher); ok { + w = caddy.WrapFlushResponseWriter(w, flushInterval) + } else { + r.log.Error("Unable to respect flush interval, provided ResponseWriter does not implement Flusher", + zap.String("type", fmt.Sprintf("%T", w)), + ) + } + } + + next.ServeHTTP(w, req) + }) + } +} + // proxyDenyMiddleware just block everything func proxyDenyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { diff --git a/server.go b/server.go index b2ae8bbf6..166a1124f 100644 --- a/server.go +++ b/server.go @@ -158,6 +158,7 @@ func createLogger(config *Config) (*zap.Logger, error) { // createReverseProxy creates a reverse proxy func (r *oauthProxy) createReverseProxy() error { + r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream)) if err := r.createUpstreamProxy(r.endpoint); err != nil { return err @@ -184,6 +185,8 @@ func (r *oauthProxy) createReverseProxy() error { engine.Use(r.securityMiddleware) } + engine.Use(r.flushIntervalMiddleware(r.config.FlushInterval)) + if len(r.config.CorsOrigins) > 0 { c := cors.New(cors.Options{ AllowedOrigins: r.config.CorsOrigins,