Skip to content

Commit

Permalink
WIP: partial flush-interval implementation for louketo#645
Browse files Browse the repository at this point in the history
TODO:
- SIGSEGV spotted when using 2s default and `make test`
- Unsure about Apache 2 & BSD-3 license compatibility as used
  • Loading branch information
Beanow committed Aug 15, 2020
1 parent 981ca45 commit a88318a
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 0 deletions.
169 changes: 169 additions & 0 deletions caddy/streaming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Most of the code in this file was initially borrowed from the Go
// standard library and modified; It had this copyright notice:
// Copyright 2011 The Go Authors

// This code was largely taken from Caddy (https://github.com/caddyserver/caddy)
// at git hash e385be922569c07a0471a6798d4aeaf972facb5b.
// The above copyright notice is what it originally had.

package caddy

import (
"io"
"net/http"
"sync"
"time"
)

type FlushResponseWriter interface {
http.ResponseWriter
http.Flusher
}

func WrapFlushResponseWriter(rw http.ResponseWriter, flushInterval time.Duration) http.ResponseWriter {
mlw := maxLatencyWriter{
dst: rw.(writeFlusher),
latency: flushInterval,
}

return &maxLatencyResponseWriter{rw, mlw}
}

// copyBuffer returns any write errors or non-EOF read errors, and the amount
// of bytes written.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, 32*1024)
}
var written int64
for {
nr, rerr := src.Read(buf)
if nr > 0 {
nw, werr := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
}
if werr != nil {
return written, werr
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if rerr != nil {
if rerr == io.EOF {
rerr = nil
}
return written, rerr
}
}
}

type writeFlusher interface {
io.Writer
http.Flusher
}

type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration // non-zero; negative means to flush immediately

mu sync.Mutex // protects t, flushPending, and dst.Flush
t *time.Timer
flushPending bool
}

// Implement io.ReaderFrom, this will allow goproxy to io.Copy to it
// with the flushing mechanics in place.
func (m *maxLatencyWriter) ReadFrom(src io.Reader) (n int64, err error) {
defer m.stop()

buf := streamingBufPool.Get().([]byte)
defer streamingBufPool.Put(buf)

return copyBuffer(m, src, buf)
}

func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
n, err = m.dst.Write(p)
if m.latency < 0 {
m.dst.Flush()
return
}
if m.flushPending {
return
}
if m.t == nil {
m.t = time.AfterFunc(m.latency, m.delayedFlush)
} else {
m.t.Reset(m.latency)
}
m.flushPending = true
return
}

func (m *maxLatencyWriter) delayedFlush() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
return
}
m.dst.Flush()
m.flushPending = false
}

func (m *maxLatencyWriter) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.flushPending = false
if m.t != nil {
m.t.Stop()
}
}

// Passes normal response writer calls to the original.
// But handles writes / copies using maxLatencyWriter.
type maxLatencyResponseWriter struct {
rw http.ResponseWriter
mlw maxLatencyWriter
}

func (f *maxLatencyResponseWriter) Header() http.Header {
return f.rw.Header()
}

func (f *maxLatencyResponseWriter) WriteHeader(status int) {
f.rw.WriteHeader(status)
}

func (f *maxLatencyResponseWriter) Write(buf []byte) (int, error) {
return f.mlw.Write(buf)
}

// Implement io.ReaderFrom, this will allow goproxy to io.Copy to it
// with the flushing mechanics in place.
func (f *maxLatencyResponseWriter) ReadFrom(src io.Reader) (int64, error) {
return f.mlw.ReadFrom(src)
}

var streamingBufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func newDefaultConfig() *Config {
UpstreamTLSHandshakeTimeout: 10 * time.Second,
UpstreamTimeout: 10 * time.Second,
UseLetsEncrypt: false,
FlushInterval: 0,
}
}

Expand Down
3 changes: 3 additions & 0 deletions config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,6 @@ cors-methods: []
cors-credentials: true|false
# the max age (Access-Control-Max-Age)
cors-max-age: 1h

# duration interval to flush reverse proxy body chunks
flush-interval: 0s
3 changes: 3 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,9 @@ type Config struct {

// DisableAllLogging indicates no logging at all
DisableAllLogging bool `json:"disable-all-logging" yaml:"disable-all-logging" usage:"disables all logging to stdout and stderr"`

// FlushInterval is the maximum latency before we should flush proxied response bodies
FlushInterval time.Duration `json:"flush-interval" yaml:"flush-interval" usage:"the rate at which to flush response body, negative value is immedate, 0 value leaves default behaviour"`
}

// getVersion returns the proxy version
Expand Down
39 changes: 39 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"sync"
"time"

"github.com/louketo/louketo-proxy/caddy"

uuid "github.com/gofrs/uuid"

"github.com/PuerkitoBio/purell"
Expand Down Expand Up @@ -63,6 +65,22 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

// In order to send chunked responses, we must implement http.Flusher
func (w *gzipResponseWriter) Flush() {
gz := w.Writer.(*gzip.Writer)
f, ok := w.ResponseWriter.(http.Flusher)

if !ok {
panic(fmt.Sprintf("Oh dear, we flushed a ResponseWriter that can't flush: %T", w.ResponseWriter))
}

err := gz.Flush()
if err != nil {
panic(err)
}
f.Flush()
}

// gzipMiddleware is responsible for compressing a response
func gzipMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -540,6 +558,27 @@ func (r *oauthProxy) securityMiddleware(next http.Handler) http.Handler {
})
}

// flushIntervalMiddleware adds a Caddy-style control over `Transfer-Encoding: chunked` behaviour
func (r *oauthProxy) flushIntervalMiddleware(flushInterval time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Setting flushInterval to 0, means to keep unchanged flushing behaviour.
if flushInterval != 0 {
// We require the destination to implement writer and flusher.
if _, ok := w.(http.Flusher); ok {
w = caddy.WrapFlushResponseWriter(w, flushInterval)
} else {
r.log.Error("Unable to respect flush interval, provided ResponseWriter does not implement Flusher",
zap.String("type", fmt.Sprintf("%T", w)),
)
}
}

next.ServeHTTP(w, req)
})
}
}

// proxyDenyMiddleware just block everything
func proxyDenyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func createLogger(config *Config) (*zap.Logger, error) {

// createReverseProxy creates a reverse proxy
func (r *oauthProxy) createReverseProxy() error {

r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream))
if err := r.createUpstreamProxy(r.endpoint); err != nil {
return err
Expand All @@ -184,6 +185,8 @@ func (r *oauthProxy) createReverseProxy() error {
engine.Use(r.securityMiddleware)
}

engine.Use(r.flushIntervalMiddleware(r.config.FlushInterval))

if len(r.config.CorsOrigins) > 0 {
c := cors.New(cors.Options{
AllowedOrigins: r.config.CorsOrigins,
Expand Down

0 comments on commit a88318a

Please sign in to comment.