Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non blocking matchers & matching timeout #192

Merged
merged 14 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 94 additions & 45 deletions layer4/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
package layer4

import (
"bytes"
"context"
"io"
"errors"
"net"
"sync"

Expand All @@ -30,7 +29,7 @@ import (
// and variable table. This function is intended for use at the start of a
// connection handler chain where the underlying connection is not yet a layer4
// Connection value.
func WrapConnection(underlying net.Conn, buf *bytes.Buffer, logger *zap.Logger) *Connection {
func WrapConnection(underlying net.Conn, buf []byte, logger *zap.Logger) *Connection {
repl := caddy.NewReplacer()
repl.Set("l4.conn.remote_addr", underlying.RemoteAddr())
repl.Set("l4.conn.local_addr", underlying.LocalAddr())
Expand Down Expand Up @@ -66,51 +65,46 @@ type Connection struct {

Logger *zap.Logger

buf *bytes.Buffer // stores recordings
bufReader io.Reader // used to read buf so it doesn't discard bytes
recording bool
buf []byte // stores matching data
offset int
frozenOffset int
matching bool

bytesRead, bytesWritten uint64
}

var ErrConsumedAllPrefetchedBytes = errors.New("consumed all prefetched bytes")
var ErrMatchingBufferFull = errors.New("matching buffer is full")

// Read implements io.Reader in such a way that reads first
// deplete any associated buffer from the prior recording,
// and once depleted (or if there isn't one), it continues
// reading from the underlying connection.
func (cx *Connection) Read(p []byte) (n int, err error) {
// if we are matching and consumed the buffer exit with error
if cx.matching && (len(cx.buf) == 0 || len(cx.buf) == cx.offset) {
return 0, ErrConsumedAllPrefetchedBytes
}

// if there is a buffer we should read from, start
// with that; we only read from the underlying conn
// after the buffer has been "depleted"
if cx.bufReader != nil {
n, err = cx.bufReader.Read(p)
if err == io.EOF {
cx.bufReader = nil
err = nil
}
// prevent first read from returning 0 bytes because of empty bufReader
if !(n == 0 && err == nil) {
return
if len(cx.buf) > 0 && cx.offset < len(cx.buf) {
n := copy(p, cx.buf[cx.offset:])
cx.offset += n
if !cx.matching && cx.offset == len(cx.buf) {
// if we are not in matching mode reset buf automatically after it was consumed
cx.offset = 0
cx.buf = cx.buf[:0]
}
return n, nil
}

// buffer has been "depleted" so read from
// underlying connection
n, err = cx.Conn.Read(p)
cx.bytesRead += uint64(n)

if !cx.recording {
return
}

// since we're recording at this point, anything that
// was read needs to be written to the buffer, even
// if there was an error
if n > 0 {
if nw, errw := cx.buf.Write(p[:n]); errw != nil {
return nw, errw
}
}

return
}

Expand All @@ -130,33 +124,75 @@ func (cx *Connection) Wrap(conn net.Conn) *Connection {
Context: cx.Context,
Logger: cx.Logger,
buf: cx.buf,
bufReader: cx.bufReader,
recording: cx.recording,
offset: cx.offset,
matching: cx.matching,
bytesRead: cx.bytesRead,
bytesWritten: cx.bytesWritten,
}
}

// record starts recording the stream into cx.buf. It also creates a reader
// to read from the buffer but not to discard any byte.
func (cx *Connection) record() {
cx.recording = true
cx.bufReader = bytes.NewReader(cx.buf.Bytes()) // Don't discard bytes.
// prefetch tries to read all bytes that a client initially sent us without blocking.
func (cx *Connection) prefetch() (err error) {
var n int
var tmp []byte

for len(cx.buf) < MaxMatchingBytes {
free := cap(cx.buf) - len(cx.buf)
if free >= PrefetchChunkSize {
n, err = cx.Conn.Read(cx.buf[len(cx.buf) : len(cx.buf)+PrefetchChunkSize])
cx.buf = cx.buf[:len(cx.buf)+n]
} else {
if tmp == nil {
tmp = bufPool.Get().([]byte)
tmp = tmp[:PrefetchChunkSize]
defer bufPool.Put(tmp)
}
n, err = cx.Conn.Read(tmp)
cx.buf = append(cx.buf, tmp[:n]...)
}

cx.bytesRead += uint64(n)

if err != nil {
return err
}

if n < PrefetchChunkSize {
break
}
}

if cx.Logger.Core().Enabled(zap.DebugLevel) {
cx.Logger.Debug("prefetched",
zap.String("remote", cx.RemoteAddr().String()),
zap.Int("bytes", len(cx.buf)),
)
}

if len(cx.buf) >= MaxMatchingBytes {
return ErrMatchingBufferFull
}

return nil
}

// freeze activates the matching mode that only reads from cx.buf.
func (cx *Connection) freeze() {
cx.matching = true
cx.frozenOffset = cx.offset
}

// rewind stops recording and creates a reader for the
// buffer so that the next reads from an associated
// recordableConn come from the buffer first, then
// continue with the underlying conn.
func (cx *Connection) rewind() {
cx.recording = false
cx.bufReader = cx.buf // Actually consume bytes.
// unfreeze stops the matching mode and resets the buffer offset
// so that the next reads come from the buffer first.
func (cx *Connection) unfreeze() {
cx.matching = false
cx.offset = cx.frozenOffset
}

// SetVar sets a value in the context's variable table with
// the given key. It overwrites any previous value with the
// same key.
func (cx Connection) SetVar(key string, value interface{}) {
func (cx *Connection) SetVar(key string, value interface{}) {
varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
if !ok {
return
Expand All @@ -167,14 +203,20 @@ func (cx Connection) SetVar(key string, value interface{}) {
// GetVar gets a value from the context's variable table with
// the given key. It returns the value if found, and true if
// it found a value with that key; false otherwise.
func (cx Connection) GetVar(key string) interface{} {
func (cx *Connection) GetVar(key string) interface{} {
varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
if !ok {
return nil
}
return varMap[key]
}

// MatchingBytes returns all bytes currently available for matching. This is only intended for reading.
// Do not write into the slice it's a view of the internal buffer and you will likely mess up the connection.
ydylla marked this conversation as resolved.
Show resolved Hide resolved
func (cx *Connection) MatchingBytes() []byte {
return cx.buf[cx.offset:]
}

var (
// VarsCtxKey is the key used to store the variables table
// in a Connection's context.
Expand All @@ -187,8 +229,15 @@ var (
listenerCtxKey caddy.CtxKey = "listener"
)

const PrefetchChunkSize = 1024
mholt marked this conversation as resolved.
Show resolved Hide resolved

// MaxMatchingBytes is the amount of bytes that are at most prefetched during matching.
// This is probably most relevant for the http matcher since http requests do not have a size limit.
// 8 KiB should cover most use-cases and is similar to popular webservers.
const MaxMatchingBytes = 8 * 1024

var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
return make([]byte, 0, PrefetchChunkSize)
},
}
21 changes: 13 additions & 8 deletions layer4/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"go.uber.org/zap"
)

func TestConnection_RecordAndRewind(t *testing.T) {
func TestConnection_FreezeAndUnfreeze(t *testing.T) {
in, out := net.Pipe()
defer in.Close()
defer out.Close()

cx := WrapConnection(out, &bytes.Buffer{}, zap.NewNop())
cx := WrapConnection(out, []byte{}, zap.NewNop())
defer cx.Close()

matcherData := []byte("foo")
Expand All @@ -26,9 +26,14 @@ func TestConnection_RecordAndRewind(t *testing.T) {
in.Write(consumeData)
}()

// 1st matcher
// prefetch like server handler would
err := cx.prefetch()
if err != nil {
t.Fatal(err)
}

cx.record()
// 1st matcher
cx.freeze()

n, err := cx.Read(buf)
if err != nil {
Expand All @@ -41,11 +46,11 @@ func TestConnection_RecordAndRewind(t *testing.T) {
t.Fatalf("expected %s but received %s", matcherData, buf)
}

cx.rewind()
cx.unfreeze()

// 2nd matcher (reads same data)

cx.record()
cx.freeze()

n, err = cx.Read(buf)
if err != nil {
Expand All @@ -58,9 +63,9 @@ func TestConnection_RecordAndRewind(t *testing.T) {
t.Fatalf("expected %s but received %s", matcherData, buf)
}

cx.rewind()
cx.unfreeze()

// 1st consumer (no record call)
// 1st consumer (no freeze call)

n, err = cx.Read(buf)
if err != nil {
Expand Down
11 changes: 10 additions & 1 deletion layer4/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ type HandlerFunc func(*Connection) error
// Handle handles a connection; it implements the Handler interface.
func (h HandlerFunc) Handle(cx *Connection) error { return h(cx) }

// NextHandlerFunc can turn a function into a NextHandler type.
type NextHandlerFunc func(cx *Connection, next Handler) error

func (h NextHandlerFunc) Handle(cx *Connection, next Handler) error { return h(cx, next) }

// nopHandler is a connection handler that does nothing with the
// connection, not even reading from it; it simply returns. It is
// the default end of all handler chains.
Expand All @@ -75,9 +80,13 @@ type nopHandler struct{}

func (nopHandler) Handle(_ *Connection) error { return nil }

type nopNextHandler struct{}

func (nopNextHandler) Handle(cx *Connection, next Handler) error { return next.Handle(cx) }

// listenerHandler is a connection handler that pipe incoming connection to channel as a listener wrapper
type listenerHandler struct{}

func (listenerHandler) Handle(conn *Connection) error {
func (listenerHandler) Handle(conn *Connection, _ Handler) error {
return conn.Context.Value(listenerCtxKey).(*listener).pipeConnection(conn)
}
19 changes: 13 additions & 6 deletions layer4/listener.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package layer4

import (
"bytes"
"context"
"crypto/tls"
"errors"
"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
"net"
"runtime"
"sync"
"time"

"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
)

func init() {
Expand All @@ -22,6 +22,9 @@ type ListenerWrapper struct {
// Routes express composable logic for handling byte streams.
Routes RouteList `json:"routes,omitempty"`

// Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s.
MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"`

compiledRoute Handler

logger *zap.Logger
Expand All @@ -41,11 +44,15 @@ func (lw *ListenerWrapper) Provision(ctx caddy.Context) error {
lw.ctx = ctx
lw.logger = ctx.Logger()

if lw.MatchingTimeout <= 0 {
lw.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault)
}

err := lw.Routes.Provision(ctx)
if err != nil {
return err
}
lw.compiledRoute = lw.Routes.Compile(listenerHandler{}, lw.logger)
lw.compiledRoute = lw.Routes.Compile(lw.logger, time.Duration(lw.MatchingTimeout), listenerHandler{})

return nil
}
Expand Down Expand Up @@ -116,8 +123,8 @@ func (l *listener) handle(conn net.Conn) {
}
}()

buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
buf := bufPool.Get().([]byte)
buf = buf[:0]
defer bufPool.Put(buf)

cx := WrapConnection(conn, buf, l.logger)
Expand Down
4 changes: 2 additions & 2 deletions layer4/matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ type MatcherSet []ConnMatcher
// or if there are no matchers. Any error terminates matching.
func (mset MatcherSet) Match(cx *Connection) (matched bool, err error) {
for _, m := range mset {
cx.record()
cx.freeze()
matched, err = m.Match(cx)
cx.rewind()
cx.unfreeze()
if cx.Logger.Core().Enabled(zap.DebugLevel) {
matcher := "unknown"
if cm, ok := m.(caddy.Module); ok {
Expand Down
Loading
Loading