diff --git a/layer4/connection.go b/layer4/connection.go index be4973c..7bd0184 100644 --- a/layer4/connection.go +++ b/layer4/connection.go @@ -15,9 +15,8 @@ package layer4 import ( - "bytes" "context" - "io" + "errors" "net" "sync" @@ -30,7 +29,7 @@ import ( // and variable table. This function is intended for use at the start of a // connection handler chain where the underlying connection is not yet a layer4 // Connection value. -func WrapConnection(underlying net.Conn, buf *bytes.Buffer, logger *zap.Logger) *Connection { +func WrapConnection(underlying net.Conn, buf []byte, logger *zap.Logger) *Connection { repl := caddy.NewReplacer() repl.Set("l4.conn.remote_addr", underlying.RemoteAddr()) repl.Set("l4.conn.local_addr", underlying.LocalAddr()) @@ -66,31 +65,39 @@ type Connection struct { Logger *zap.Logger - buf *bytes.Buffer // stores recordings - bufReader io.Reader // used to read buf so it doesn't discard bytes - recording bool + buf []byte // stores matching data + offset int + frozenOffset int + matching bool bytesRead, bytesWritten uint64 } +var ErrConsumedAllPrefetchedBytes = errors.New("consumed all prefetched bytes") +var ErrMatchingBufferFull = errors.New("matching buffer is full") + // Read implements io.Reader in such a way that reads first // deplete any associated buffer from the prior recording, // and once depleted (or if there isn't one), it continues // reading from the underlying connection. func (cx *Connection) Read(p []byte) (n int, err error) { + // if we are matching and consumed the buffer exit with error + if cx.matching && (len(cx.buf) == 0 || len(cx.buf) == cx.offset) { + return 0, ErrConsumedAllPrefetchedBytes + } + // if there is a buffer we should read from, start // with that; we only read from the underlying conn // after the buffer has been "depleted" - if cx.bufReader != nil { - n, err = cx.bufReader.Read(p) - if err == io.EOF { - cx.bufReader = nil - err = nil - } - // prevent first read from returning 0 bytes because of empty bufReader - if !(n == 0 && err == nil) { - return + if len(cx.buf) > 0 && cx.offset < len(cx.buf) { + n := copy(p, cx.buf[cx.offset:]) + cx.offset += n + if !cx.matching && cx.offset == len(cx.buf) { + // if we are not in matching mode reset buf automatically after it was consumed + cx.offset = 0 + cx.buf = cx.buf[:0] } + return n, nil } // buffer has been "depleted" so read from @@ -98,19 +105,6 @@ func (cx *Connection) Read(p []byte) (n int, err error) { n, err = cx.Conn.Read(p) cx.bytesRead += uint64(n) - if !cx.recording { - return - } - - // since we're recording at this point, anything that - // was read needs to be written to the buffer, even - // if there was an error - if n > 0 { - if nw, errw := cx.buf.Write(p[:n]); errw != nil { - return nw, errw - } - } - return } @@ -130,33 +124,75 @@ func (cx *Connection) Wrap(conn net.Conn) *Connection { Context: cx.Context, Logger: cx.Logger, buf: cx.buf, - bufReader: cx.bufReader, - recording: cx.recording, + offset: cx.offset, + matching: cx.matching, bytesRead: cx.bytesRead, bytesWritten: cx.bytesWritten, } } -// record starts recording the stream into cx.buf. It also creates a reader -// to read from the buffer but not to discard any byte. -func (cx *Connection) record() { - cx.recording = true - cx.bufReader = bytes.NewReader(cx.buf.Bytes()) // Don't discard bytes. +// prefetch tries to read all bytes that a client initially sent us without blocking. +func (cx *Connection) prefetch() (err error) { + var n int + var tmp []byte + + for len(cx.buf) < MaxMatchingBytes { + free := cap(cx.buf) - len(cx.buf) + if free >= prefetchChunkSize { + n, err = cx.Conn.Read(cx.buf[len(cx.buf) : len(cx.buf)+prefetchChunkSize]) + cx.buf = cx.buf[:len(cx.buf)+n] + } else { + if tmp == nil { + tmp = bufPool.Get().([]byte) + tmp = tmp[:prefetchChunkSize] + defer bufPool.Put(tmp) + } + n, err = cx.Conn.Read(tmp) + cx.buf = append(cx.buf, tmp[:n]...) + } + + cx.bytesRead += uint64(n) + + if err != nil { + return err + } + + if n < prefetchChunkSize { + break + } + } + + if cx.Logger.Core().Enabled(zap.DebugLevel) { + cx.Logger.Debug("prefetched", + zap.String("remote", cx.RemoteAddr().String()), + zap.Int("bytes", len(cx.buf)), + ) + } + + if len(cx.buf) >= MaxMatchingBytes { + return ErrMatchingBufferFull + } + + return nil +} + +// freeze activates the matching mode that only reads from cx.buf. +func (cx *Connection) freeze() { + cx.matching = true + cx.frozenOffset = cx.offset } -// rewind stops recording and creates a reader for the -// buffer so that the next reads from an associated -// recordableConn come from the buffer first, then -// continue with the underlying conn. -func (cx *Connection) rewind() { - cx.recording = false - cx.bufReader = cx.buf // Actually consume bytes. +// unfreeze stops the matching mode and resets the buffer offset +// so that the next reads come from the buffer first. +func (cx *Connection) unfreeze() { + cx.matching = false + cx.offset = cx.frozenOffset } // SetVar sets a value in the context's variable table with // the given key. It overwrites any previous value with the // same key. -func (cx Connection) SetVar(key string, value interface{}) { +func (cx *Connection) SetVar(key string, value interface{}) { varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{}) if !ok { return @@ -167,7 +203,7 @@ func (cx Connection) SetVar(key string, value interface{}) { // GetVar gets a value from the context's variable table with // the given key. It returns the value if found, and true if // it found a value with that key; false otherwise. -func (cx Connection) GetVar(key string) interface{} { +func (cx *Connection) GetVar(key string) interface{} { varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{}) if !ok { return nil @@ -175,6 +211,12 @@ func (cx Connection) GetVar(key string) interface{} { return varMap[key] } +// MatchingBytes returns all bytes currently available for matching. This is only intended for reading. +// Do not write into the slice. It's a view of the internal buffer and you will likely mess up the connection. +func (cx *Connection) MatchingBytes() []byte { + return cx.buf[cx.offset:] +} + var ( // VarsCtxKey is the key used to store the variables table // in a Connection's context. @@ -187,8 +229,15 @@ var ( listenerCtxKey caddy.CtxKey = "listener" ) +const prefetchChunkSize = 1024 + +// MaxMatchingBytes is the amount of bytes that are at most prefetched during matching. +// This is probably most relevant for the http matcher since http requests do not have a size limit. +// 8 KiB should cover most use-cases and is similar to popular webservers. +const MaxMatchingBytes = 8 * 1024 + var bufPool = sync.Pool{ New: func() interface{} { - return new(bytes.Buffer) + return make([]byte, 0, prefetchChunkSize) }, } diff --git a/layer4/connection_test.go b/layer4/connection_test.go index e610582..d7626e6 100644 --- a/layer4/connection_test.go +++ b/layer4/connection_test.go @@ -8,12 +8,12 @@ import ( "go.uber.org/zap" ) -func TestConnection_RecordAndRewind(t *testing.T) { +func TestConnection_FreezeAndUnfreeze(t *testing.T) { in, out := net.Pipe() defer in.Close() defer out.Close() - cx := WrapConnection(out, &bytes.Buffer{}, zap.NewNop()) + cx := WrapConnection(out, []byte{}, zap.NewNop()) defer cx.Close() matcherData := []byte("foo") @@ -26,9 +26,14 @@ func TestConnection_RecordAndRewind(t *testing.T) { in.Write(consumeData) }() - // 1st matcher + // prefetch like server handler would + err := cx.prefetch() + if err != nil { + t.Fatal(err) + } - cx.record() + // 1st matcher + cx.freeze() n, err := cx.Read(buf) if err != nil { @@ -41,11 +46,11 @@ func TestConnection_RecordAndRewind(t *testing.T) { t.Fatalf("expected %s but received %s", matcherData, buf) } - cx.rewind() + cx.unfreeze() // 2nd matcher (reads same data) - cx.record() + cx.freeze() n, err = cx.Read(buf) if err != nil { @@ -58,9 +63,9 @@ func TestConnection_RecordAndRewind(t *testing.T) { t.Fatalf("expected %s but received %s", matcherData, buf) } - cx.rewind() + cx.unfreeze() - // 1st consumer (no record call) + // 1st consumer (no freeze call) n, err = cx.Read(buf) if err != nil { diff --git a/layer4/handlers.go b/layer4/handlers.go index 14ac43c..d41e5dd 100644 --- a/layer4/handlers.go +++ b/layer4/handlers.go @@ -60,6 +60,11 @@ type HandlerFunc func(*Connection) error // Handle handles a connection; it implements the Handler interface. func (h HandlerFunc) Handle(cx *Connection) error { return h(cx) } +// NextHandlerFunc can turn a function into a NextHandler type. +type NextHandlerFunc func(cx *Connection, next Handler) error + +func (h NextHandlerFunc) Handle(cx *Connection, next Handler) error { return h(cx, next) } + // nopHandler is a connection handler that does nothing with the // connection, not even reading from it; it simply returns. It is // the default end of all handler chains. @@ -75,9 +80,13 @@ type nopHandler struct{} func (nopHandler) Handle(_ *Connection) error { return nil } +type nopNextHandler struct{} + +func (nopNextHandler) Handle(cx *Connection, next Handler) error { return next.Handle(cx) } + // listenerHandler is a connection handler that pipe incoming connection to channel as a listener wrapper type listenerHandler struct{} -func (listenerHandler) Handle(conn *Connection) error { +func (listenerHandler) Handle(conn *Connection, _ Handler) error { return conn.Context.Value(listenerCtxKey).(*listener).pipeConnection(conn) } diff --git a/layer4/listener.go b/layer4/listener.go index 81c4ae1..de60abe 100644 --- a/layer4/listener.go +++ b/layer4/listener.go @@ -1,16 +1,16 @@ package layer4 import ( - "bytes" "context" "crypto/tls" "errors" - "github.com/caddyserver/caddy/v2" - "go.uber.org/zap" "net" "runtime" "sync" "time" + + "github.com/caddyserver/caddy/v2" + "go.uber.org/zap" ) func init() { @@ -22,6 +22,9 @@ type ListenerWrapper struct { // Routes express composable logic for handling byte streams. Routes RouteList `json:"routes,omitempty"` + // Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s. + MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"` + compiledRoute Handler logger *zap.Logger @@ -41,11 +44,15 @@ func (lw *ListenerWrapper) Provision(ctx caddy.Context) error { lw.ctx = ctx lw.logger = ctx.Logger() + if lw.MatchingTimeout <= 0 { + lw.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault) + } + err := lw.Routes.Provision(ctx) if err != nil { return err } - lw.compiledRoute = lw.Routes.Compile(listenerHandler{}, lw.logger) + lw.compiledRoute = lw.Routes.Compile(lw.logger, time.Duration(lw.MatchingTimeout), listenerHandler{}) return nil } @@ -116,8 +123,8 @@ func (l *listener) handle(conn net.Conn) { } }() - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() + buf := bufPool.Get().([]byte) + buf = buf[:0] defer bufPool.Put(buf) cx := WrapConnection(conn, buf, l.logger) diff --git a/layer4/matchers.go b/layer4/matchers.go index 0b3a9ed..14b6006 100644 --- a/layer4/matchers.go +++ b/layer4/matchers.go @@ -48,9 +48,9 @@ type MatcherSet []ConnMatcher // or if there are no matchers. Any error terminates matching. func (mset MatcherSet) Match(cx *Connection) (matched bool, err error) { for _, m := range mset { - cx.record() + cx.freeze() matched, err = m.Match(cx) - cx.rewind() + cx.unfreeze() if cx.Logger.Core().Enabled(zap.DebugLevel) { matcher := "unknown" if cm, ok := m.(caddy.Module); ok { diff --git a/layer4/matchers_test.go b/layer4/matchers_test.go index 4391a67..0f447fc 100644 --- a/layer4/matchers_test.go +++ b/layer4/matchers_test.go @@ -1,7 +1,6 @@ package layer4 import ( - "bytes" "net" "testing" @@ -70,7 +69,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -89,7 +87,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "192.168.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -108,7 +105,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "192.168.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -130,7 +126,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "172.16.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -152,7 +147,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "192.168.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "192.168.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -174,7 +168,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "172.16.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ @@ -194,7 +187,6 @@ func TestNotMatcher(t *testing.T) { localAddr: dummyAddr{ip: "127.0.0.1", network: "tcp"}, remoteAddr: dummyAddr{ip: "192.168.0.1", network: "tcp"}, }, - buf: &bytes.Buffer{}, Logger: zap.NewNop(), }, matcher: MatchNot{ diff --git a/layer4/routes.go b/layer4/routes.go index 673341c..62ebfc3 100644 --- a/layer4/routes.go +++ b/layer4/routes.go @@ -16,7 +16,10 @@ package layer4 import ( "encoding/json" + "errors" "fmt" + "os" + "time" "github.com/caddyserver/caddy/v2" "go.uber.org/zap" @@ -43,6 +46,8 @@ type Route struct { middleware []Middleware } +var ErrMatchingTimeout = errors.New("aborted matching according to timeout") + // Provision sets up a route. func (r *Route) Provision(ctx caddy.Context) error { // matchers @@ -62,7 +67,8 @@ func (r *Route) Provision(ctx caddy.Context) error { } var handlers Handlers for _, mod := range mods.([]interface{}) { - handlers = append(handlers, mod.(NextHandler)) + handler := mod.(NextHandler) + handlers = append(handlers, handler) } for _, midhandler := range handlers { r.middleware = append(r.middleware, wrapHandler(midhandler)) @@ -92,75 +98,71 @@ func (routes RouteList) Provision(ctx caddy.Context) error { // Compile prepares a middleware chain from the route list. // This should only be done once: after all the routes have // been provisioned, and before the server loop begins. -func (routes RouteList) Compile(next Handler, logger *zap.Logger) Handler { - mid := make([]Middleware, 0, len(routes)) - for _, route := range routes { - mid = append(mid, wrapRoute(route, logger)) - } - stack := next - for i := len(mid) - 1; i >= 0; i-- { - stack = mid[i](stack) - } - return stack -} - -// wrapRoute wraps route with a middleware and handler so that it can -// be chained in and defer evaluation of its matchers to request-time. -// Like wrapMiddleware, it is vital that this wrapping takes place in -// its own stack frame so as to not overwrite the reference to the -// intended route by looping and changing the reference each time. -func wrapRoute(route *Route, logger *zap.Logger) Middleware { - return func(next Handler) Handler { - return HandlerFunc(func(cx *Connection) error { - // TODO: Update this comment, it seems we've moved the copy into the handler? - // copy the next handler (it's an interface, so it's just - // a very lightweight copy of a pointer); this is important - // because this is a closure to the func below, which - // re-assigns the value as it compiles the middleware stack; - // if we don't make this copy, we'd affect the underlying - // pointer for all future request (yikes); we could - // alternatively solve this by moving the func below out of - // this closure and into a standalone package-level func, - // but I just thought this made more sense - nextCopy := next - - // route must match at least one of the matcher sets - matched, err := route.matcherSets.AnyMatch(cx) +func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duration, next NextHandler) Handler { + return HandlerFunc(func(cx *Connection) error { + deadline := time.Now().Add(matchingTimeout) + router: + // timeout matching to protect against malicious or very slow clients + err := cx.Conn.SetReadDeadline(deadline) + if err != nil { + return err + } + for { // retry prefetching and matching routes until timeout + err = cx.prefetch() if err != nil { - logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) + logFunc := logger.Error + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrMatchingTimeout + logFunc = logger.Warn + } + logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) return nil // return nil so the error does not get logged again } - if !matched { - return nextCopy.Handle(cx) + for _, route := range routes { + // A route must match at least one of the matcher sets + matched, err := route.matcherSets.AnyMatch(cx) + if errors.Is(err, ErrConsumedAllPrefetchedBytes) { + continue // ignore and try next route + } + if err != nil { + logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) + return nil + } + if matched { + // remove deadline after we matched + err = cx.Conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + isTerminal := true + lastHandler := HandlerFunc(func(conn *Connection) error { + // Catch potentially wrapped connection to use it as input for the next round of route matching. + // This is for example required for matchers after a tls handler. + cx = conn + // If this handler is called all handlers before where not terminal + isTerminal = false + return nil + }) + // compile the route handler stack with lastHandler being called last + handler := wrapHandler(next)(lastHandler) + for i := len(route.middleware) - 1; i >= 0; i-- { + handler = route.middleware[i](handler) + } + err = handler.Handle(cx) + if err != nil { + return err + } + + // If handler is terminal we stop routing, + // otherwise we jump back to the start of the routing loop to peel of more protocol layers. + if isTerminal { + return nil + } else { + goto router + } + } } - - // TODO: other routing features? - - // // if route is part of a group, ensure only the - // // first matching route in the group is applied - // if route.Group != "" { - // groups := req.Context().Value(routeGroupCtxKey).(map[string]struct{}) - - // if _, ok := groups[route.Group]; ok { - // // this group has already been - // // satisfied by a matching route - // return nextCopy.ServeHTTP(rw, req) - // } - - // // this matching route satisfies the group - // groups[route.Group] = struct{}{} - // } - - // // make terminal routes terminate - // if route.Terminal { - // nextCopy = emptyHandler - // } - - // compile this route's handler stack - for i := len(route.middleware) - 1; i >= 0; i-- { - nextCopy = route.middleware[i](nextCopy) - } - return nextCopy.Handle(cx) - }) - } + } + }) } diff --git a/layer4/routes_test.go b/layer4/routes_test.go new file mode 100644 index 0000000..1db54c7 --- /dev/null +++ b/layer4/routes_test.go @@ -0,0 +1,66 @@ +package layer4 + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/caddyserver/caddy/v2" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func TestMatchingTimeoutWorks(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + routes := RouteList{&Route{}} + + err := routes.Provision(ctx) + if err != nil { + t.Fatalf("provision failed | %s", err) + } + + matched := false + loggerCore, logs := observer.New(zapcore.WarnLevel) + compiledRoutes := routes.Compile(zap.New(loggerCore), 5*time.Millisecond, + NextHandlerFunc(func(con *Connection, next Handler) error { + matched = true + return next.Handle(con) + })) + + in, out := net.Pipe() + defer in.Close() + defer out.Close() + + cx := WrapConnection(out, []byte{}, zap.NewNop()) + defer cx.Close() + + err = compiledRoutes.Handle(cx) + if err != nil { + t.Fatalf("handle failed | %s", err) + } + + // verify the matching aborted error was logged + if logs.Len() != 1 { + t.Fatalf("logs should contain 1 entry but has %d", logs.Len()) + } + logEntry := logs.All()[0] + if logEntry.Level != zapcore.WarnLevel { + t.Fatalf("wrong log level | %s", logEntry.Level) + } + if logEntry.Message != "matching connection" { + t.Fatalf("wrong log message | %s", logEntry.Message) + } + if !(logEntry.Context[1].Key == "error" && errors.Is(logEntry.Context[1].Interface.(error), ErrMatchingTimeout)) { + t.Fatalf("wrong error | %v", logEntry.Context[1].Interface) + } + + // since matching failed no handler should be called + if matched { + t.Fatal("handler was called but should not") + } +} diff --git a/layer4/server.go b/layer4/server.go index 6f1609f..b238238 100644 --- a/layer4/server.go +++ b/layer4/server.go @@ -25,6 +25,8 @@ import ( "go.uber.org/zap" ) +const MatchingTimeoutDefault = 3 * time.Second + // Server represents a Caddy layer4 server. type Server struct { // The network address to bind to. Any Caddy network address @@ -35,6 +37,9 @@ type Server struct { // Routes express composable logic for handling byte streams. Routes RouteList `json:"routes,omitempty"` + // Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s. + MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"` + logger *zap.Logger listenAddrs []caddy.NetworkAddress compiledRoute Handler @@ -44,6 +49,10 @@ type Server struct { func (s *Server) Provision(ctx caddy.Context, logger *zap.Logger) error { s.logger = logger + if s.MatchingTimeout <= 0 { + s.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault) + } + for i, address := range s.Listen { addr, err := caddy.ParseNetworkAddress(address) if err != nil { @@ -56,7 +65,7 @@ func (s *Server) Provision(ctx caddy.Context, logger *zap.Logger) error { if err != nil { return err } - s.compiledRoute = s.Routes.Compile(nopHandler{}, s.logger) + s.compiledRoute = s.Routes.Compile(s.logger, time.Duration(s.MatchingTimeout), nopNextHandler{}) return nil } @@ -99,8 +108,8 @@ func (s Server) servePacket(pc net.PacketConn) error { func (s Server) handle(conn net.Conn) { defer conn.Close() - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() + buf := bufPool.Get().([]byte) + buf = buf[:0] defer bufPool.Put(buf) cx := WrapConnection(conn, buf, s.logger) diff --git a/modules/l4http/httpmatcher.go b/modules/l4http/httpmatcher.go index 7bfdbc2..c80fa7b 100644 --- a/modules/l4http/httpmatcher.go +++ b/modules/l4http/httpmatcher.go @@ -16,17 +16,19 @@ package l4http import ( "bufio" + "bytes" "encoding/json" "fmt" + "io" + "net/http" + "net/url" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/mholt/caddy-l4/layer4" "github.com/mholt/caddy-l4/modules/l4tls" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - "io" - "net/http" - "net/url" ) func init() { @@ -78,11 +80,18 @@ func (m MatchHTTP) Match(cx *layer4.Connection) (bool, error) { req, ok := cx.GetVar("http_request").(*http.Request) if !ok { var err error - bufReader := bufio.NewReader(cx) + + data := cx.MatchingBytes() + if !m.isHttp(data) { + return false, nil + } + + // use bufio reader which exactly matches the size of prefetched data, + // to not trigger all bytes consumed error + bufReader := bufio.NewReaderSize(cx, len(data)) req, err = http.ReadRequest(bufReader) if err != nil { - // TODO: find a way to distinguish actual errors from mismatches - return false, nil + return false, err } // check if req is a http2 request made with prior knowledge and if so parse it @@ -113,6 +122,23 @@ func (m MatchHTTP) Match(cx *layer4.Connection) (bool, error) { return m.matcherSets.AnyMatch(req), nil } +func (m MatchHTTP) isHttp(data []byte) bool { + // try to find the end of a http request line, for example " HTTP/1.1\r\n" + i := bytes.IndexByte(data, 0x0a) // find first new line + if i < 10 { + return false + } + // assume only \n line ending + start := i - 9 // position of space in front of HTTP + end := i - 3 // cut off version number "1.1" or "2.0" + // if we got a correct \r\n line ending shift the calculated start & end to the left + if data[i-1] == 0x0d { + start -= 1 + end -= 1 + } + return bytes.Compare(data[start:end], []byte(" HTTP/")) == 0 +} + // Parses information from a http2 request with prior knowledge (RFC 7540 Section 3.4) func (m MatchHTTP) handleHttp2WithPriorKnowledge(reader io.Reader, req *http.Request) error { // Does req contain a valid http2 magic? diff --git a/modules/l4http/httpmatcher_test.go b/modules/l4http/httpmatcher_test.go index 555658c..19d1e0a 100644 --- a/modules/l4http/httpmatcher_test.go +++ b/modules/l4http/httpmatcher_test.go @@ -1,15 +1,13 @@ package l4http import ( - "bytes" "context" "crypto/tls" "encoding/base64" "encoding/json" - "io" "net" - "sync" "testing" + "time" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" @@ -24,22 +22,13 @@ func assertNoError(t *testing.T, err error) { } } -func httpMatchTester(t *testing.T, matcherSets caddyhttp.RawMatcherSets, data []byte) (bool, error) { - wg := &sync.WaitGroup{} +func httpMatchTester(t *testing.T, matchers json.RawMessage, data []byte) (bool, error) { in, out := net.Pipe() - defer func() { - wg.Wait() - _ = in.Close() - _ = out.Close() - }() + defer in.Close() + defer out.Close() - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, make([]byte, 0), zap.NewNop()) go func() { - wg.Add(1) - defer func() { - wg.Done() - _ = out.Close() - }() _, err := out.Write(data) assertNoError(t, err) }() @@ -47,13 +36,23 @@ func httpMatchTester(t *testing.T, matcherSets caddyhttp.RawMatcherSets, data [] ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) defer cancel() - matcher := MatchHTTP{MatcherSetsRaw: matcherSets} - err := matcher.Provision(ctx) + routes := layer4.RouteList{&layer4.Route{ + MatcherSetsRaw: caddyhttp.RawMatcherSets{ + caddy.ModuleMap{"http": matchers}, + }, + }} + err := routes.Provision(ctx) assertNoError(t, err) - matched, err := matcher.Match(cx) + matched := false + compiledRoute := routes.Compile(zap.NewNop(), 10*time.Millisecond, + layer4.NextHandlerFunc(func(con *layer4.Connection, _ layer4.Handler) error { + matched = true + return nil + })) - _, _ = io.Copy(io.Discard, in) + err = compiledRoute.Handle(cx) + assertNoError(t, err) return matched, err } @@ -62,43 +61,43 @@ func TestHttp1Matching(t *testing.T) { http1RequestExample := []byte("GET /foo/bar?aaa=bbb HTTP/1.1\nHost: localhost:10443\nUser-Agent: curl/7.82.0\nAccept: */*\n\n") for _, tc := range []struct { - name string - matcherSets caddyhttp.RawMatcherSets - data []byte + name string + matchers json.RawMessage + data []byte }{ { - name: "match-by-host", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"host": json.RawMessage("[\"localhost\"]")}}, - data: http1RequestExample, + name: "match-by-host", + matchers: json.RawMessage("[{\"host\":[\"localhost\"]}]"), + data: http1RequestExample, }, { - name: "match-by-method", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"method": json.RawMessage("[\"GET\"]")}}, - data: http1RequestExample, + name: "match-by-method", + matchers: json.RawMessage("[{\"method\":[\"GET\"]}]"), + data: http1RequestExample, }, { - name: "match-by-path", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"path": json.RawMessage("[\"/foo/bar\"]")}}, - data: http1RequestExample, + name: "match-by-path", + matchers: json.RawMessage("[{\"path\":[\"/foo/bar\"]}]"), + data: http1RequestExample, }, { - name: "match-by-query", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"query": json.RawMessage("{\"aaa\":[\"bbb\"]}")}}, - data: http1RequestExample, + name: "match-by-query", + matchers: json.RawMessage("[{\"query\":{\"aaa\":[\"bbb\"]}}]"), + data: http1RequestExample, }, { - name: "match-by-header", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"header": json.RawMessage("{\"user-agent\":[\"curl*\"]}")}}, - data: http1RequestExample, + name: "match-by-header", + matchers: json.RawMessage("[{\"header\":{\"user-agent\":[\"curl*\"]}}]"), + data: http1RequestExample, }, { - name: "match-by-protocol", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"protocol": json.RawMessage("\"http\"")}}, - data: http1RequestExample, + name: "match-by-protocol", + matchers: json.RawMessage("[{\"protocol\":\"http\"}]"), + data: http1RequestExample, }, } { t.Run(tc.name, func(t *testing.T) { - matched, err := httpMatchTester(t, tc.matcherSets, tc.data) + matched, err := httpMatchTester(t, tc.matchers, tc.data) assertNoError(t, err) if !matched { t.Errorf("matcher did not match") @@ -115,74 +114,74 @@ func TestHttp2Matching(t *testing.T) { assertNoError(t, err) for _, tc := range []struct { - name string - matcherSets caddyhttp.RawMatcherSets - data []byte + name string + matchers json.RawMessage + data []byte }{ { - name: "match-by-host", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"host": json.RawMessage("[\"localhost\"]")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-host", + matchers: json.RawMessage("[{\"host\":[\"localhost\"]}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "match-by-method", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"method": json.RawMessage("[\"GET\"]")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-method", + matchers: json.RawMessage("[{\"method\":[\"GET\"]}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "match-by-path", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"path": json.RawMessage("[\"/foo/bar\"]")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-path", + matchers: json.RawMessage("[{\"path\":[\"/foo/bar\"]}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "match-by-query", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"query": json.RawMessage("{\"aaa\":[\"bbb\"]}")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-query", + matchers: json.RawMessage("[{\"query\":{\"aaa\":[\"bbb\"]}}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "match-by-header", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"header": json.RawMessage("{\"user-agent\":[\"curl*\"]}")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-header", + matchers: json.RawMessage("[{\"header\":{\"user-agent\":[\"curl*\"]}}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "match-by-protocol", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"protocol": json.RawMessage("\"http\"")}}, - data: http2PriorKnowledgeRequestExample, + name: "match-by-protocol", + matchers: json.RawMessage("[{\"protocol\":\"http\"}]"), + data: http2PriorKnowledgeRequestExample, }, { - name: "upgrade-match-by-host", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"host": json.RawMessage("[\"localhost\"]")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-host", + matchers: json.RawMessage("[{\"host\":[\"localhost\"]}]"), + data: http2UpgradeRequestExample, }, { - name: "upgrade-match-by-method", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"method": json.RawMessage("[\"GET\"]")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-method", + matchers: json.RawMessage("[{\"method\":[\"GET\"]}]"), + data: http2UpgradeRequestExample, }, { - name: "upgrade-match-by-path", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"path": json.RawMessage("[\"/foo/bar\"]")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-path", + matchers: json.RawMessage("[{\"path\":[\"/foo/bar\"]}]"), + data: http2UpgradeRequestExample, }, { - name: "upgrade-match-by-query", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"query": json.RawMessage("{\"aaa\":[\"bbb\"]}")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-query", + matchers: json.RawMessage("[{\"query\":{\"aaa\":[\"bbb\"]}}]"), + data: http2UpgradeRequestExample, }, { - name: "upgrade-match-by-header", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"header": json.RawMessage("{\"user-agent\":[\"curl*\"]}")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-header", + matchers: json.RawMessage("[{\"header\":{\"user-agent\":[\"curl*\"]}}]"), + data: http2UpgradeRequestExample, }, { - name: "upgrade-match-by-protocol", - matcherSets: caddyhttp.RawMatcherSets{caddy.ModuleMap{"protocol": json.RawMessage("\"http\"")}}, - data: http2UpgradeRequestExample, + name: "upgrade-match-by-protocol", + matchers: json.RawMessage("[{\"protocol\":\"http\"}]"), + data: http2UpgradeRequestExample, }, } { t.Run(tc.name, func(t *testing.T) { - matched, err := httpMatchTester(t, tc.matcherSets, tc.data) + matched, err := httpMatchTester(t, tc.matchers, tc.data) assertNoError(t, err) if !matched { t.Errorf("matcher did not match") @@ -192,23 +191,31 @@ func TestHttp2Matching(t *testing.T) { } func TestHttpMatchingByProtocolWithHttps(t *testing.T) { - matcherSets := caddyhttp.RawMatcherSets{caddy.ModuleMap{"protocol": json.RawMessage("\"https\"")}} + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + routes := layer4.RouteList{&layer4.Route{ + MatcherSetsRaw: caddyhttp.RawMatcherSets{ + caddy.ModuleMap{"http": json.RawMessage("[{\"protocol\":\"https\"}]")}, + }, + }} + + err := routes.Provision(ctx) + assertNoError(t, err) + + handlerCalled := false + compiledRoute := routes.Compile(zap.NewNop(), 100*time.Millisecond, + layer4.NextHandlerFunc(func(con *layer4.Connection, _ layer4.Handler) error { + handlerCalled = true + return nil + })) - wg := &sync.WaitGroup{} in, out := net.Pipe() - defer func() { - wg.Wait() - _ = in.Close() - _ = out.Close() - }() + defer in.Close() + defer out.Close() - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { - wg.Add(1) - defer func() { - wg.Done() - _ = out.Close() - }() _, err := out.Write([]byte("GET /foo/bar?aaa=bbb HTTP/1.1\nHost: localhost:10443\n\n")) assertNoError(t, err) }() @@ -216,26 +223,17 @@ func TestHttpMatchingByProtocolWithHttps(t *testing.T) { // pretend the tls handler was executed before, not an ideal test setup but better then nothing cx.SetVar("tls_connection_states", []*tls.ConnectionState{{ServerName: "localhost"}}) - ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) - defer cancel() - - matcher := MatchHTTP{MatcherSetsRaw: matcherSets} - err := matcher.Provision(ctx) + err = compiledRoute.Handle(cx) assertNoError(t, err) - - matched, err := matcher.Match(cx) - assertNoError(t, err) - if !matched { + if !handlerCalled { t.Fatalf("matcher did not match") } - - _, _ = io.Copy(io.Discard, in) } func TestHttpMatchingGarbage(t *testing.T) { - matcherSets := caddyhttp.RawMatcherSets{caddy.ModuleMap{"host": json.RawMessage("[\"localhost\"]")}} + matchers := json.RawMessage("[{\"host\":[\"localhost\"]}]") - matched, err := httpMatchTester(t, matcherSets, []byte("not a valid http request")) + matched, err := httpMatchTester(t, matchers, []byte("not a valid http request")) assertNoError(t, err) if matched { t.Fatalf("matcher did match") @@ -243,11 +241,59 @@ func TestHttpMatchingGarbage(t *testing.T) { validHttp2MagicWithoutHeadersFrame, err := base64.StdEncoding.DecodeString("UFJJICogSFRUUC8yLjANCg0KU00NCg0KAAASBAAAAAAAAAMAAABkAAQCAAAAAAIAAAAATm8gbG9uZ2VyIHZhbGlkIGh0dHAyIHJlcXVlc3QgZnJhbWVz") assertNoError(t, err) - matched, err = httpMatchTester(t, matcherSets, validHttp2MagicWithoutHeadersFrame) + matched, err = httpMatchTester(t, matchers, validHttp2MagicWithoutHeadersFrame) if matched { t.Fatalf("matcher did match") } - if err == nil || err.Error() != "unexpected EOF" { - t.Fatalf("handler did not return an error or the wrong error -> %v", err) +} + +func TestMatchHTTP_isHttp(t *testing.T) { + for _, tc := range []struct { + name string + data []byte + shouldMatch bool + }{ + { + name: "http/1.1-only-lf", + data: []byte("GET /foo/bar?aaa=bbb HTTP/1.1\nHost: localhost:10443\n\n"), + shouldMatch: true, + }, + { + name: "http/1.1-cr-lf", + data: []byte("GET /foo/bar?aaa=bbb HTTP/1.1\r\nHost: localhost:10443\r\n\r\n"), + shouldMatch: true, + }, + { + name: "http/1.0-cr-lf", + data: []byte("GET /foo/bar?aaa=bbb HTTP/1.0\r\nHost: localhost:10443\r\n\r\n"), + shouldMatch: true, + }, + { + name: "http/2.0-cr-lf", + data: []byte("PRI * HTTP/2.0\r\n\r\n"), + shouldMatch: true, + }, + { + name: "dummy-short", + data: []byte("dum\n"), + shouldMatch: false, + }, + { + name: "dummy-long", + data: []byte("dummydummydummy\n"), + shouldMatch: false, + }, + { + name: "http/1.1-without-space-in-front", + data: []byte("HTTP/1.1\n"), + shouldMatch: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + matched := MatchHTTP{}.isHttp(tc.data) + if matched != tc.shouldMatch { + t.Fatalf("test %v | matched: %v != shouldMatch: %v", tc.name, matched, tc.shouldMatch) + } + }) } } diff --git a/modules/l4proxyprotocol/handler_test.go b/modules/l4proxyprotocol/handler_test.go index 4f3610f..741592a 100644 --- a/modules/l4proxyprotocol/handler_test.go +++ b/modules/l4proxyprotocol/handler_test.go @@ -1,7 +1,6 @@ package l4proxyprotocol import ( - "bytes" "context" "io" "net" @@ -25,7 +24,7 @@ func TestProxyProtocolHandleV1(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() @@ -63,7 +62,7 @@ func TestProxyProtocolHandleV2(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() @@ -101,7 +100,7 @@ func TestProxyProtocolHandleGarbage(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() diff --git a/modules/l4proxyprotocol/matcher_test.go b/modules/l4proxyprotocol/matcher_test.go index 00dc598..c34d33c 100644 --- a/modules/l4proxyprotocol/matcher_test.go +++ b/modules/l4proxyprotocol/matcher_test.go @@ -1,7 +1,6 @@ package l4proxyprotocol import ( - "bytes" "encoding/hex" "io" "net" @@ -33,7 +32,7 @@ func TestProxyProtocolMatchV1(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() @@ -59,7 +58,7 @@ func TestProxyProtocolMatchV2(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() @@ -85,7 +84,7 @@ func TestProxyProtocolMatchGarbage(t *testing.T) { in, out := net.Pipe() defer closePipe(wg, in, out) - cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) go func() { wg.Add(1) defer wg.Done() diff --git a/modules/l4socks/socks4_matcher_test.go b/modules/l4socks/socks4_matcher_test.go index 77fb0d8..6757b70 100644 --- a/modules/l4socks/socks4_matcher_test.go +++ b/modules/l4socks/socks4_matcher_test.go @@ -1,7 +1,6 @@ package l4socks import ( - "bytes" "context" "io" "net" @@ -78,7 +77,7 @@ func TestSocks4Matcher_Match(t *testing.T) { _ = out.Close() }() - cx := layer4.WrapConnection(out, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(out, []byte{}, zap.NewNop()) go func() { _, err := in.Write(tc.data) assertNoError(t, err) diff --git a/modules/l4socks/socks5_handler.go b/modules/l4socks/socks5_handler.go index bfefeb5..8bb1c33 100644 --- a/modules/l4socks/socks5_handler.go +++ b/modules/l4socks/socks5_handler.go @@ -2,12 +2,13 @@ package l4socks import ( "fmt" + "net" + "strings" + "github.com/caddyserver/caddy/v2" "github.com/mholt/caddy-l4/layer4" "github.com/things-go/go-socks5" "go.uber.org/zap" - "net" - "strings" ) func init() { diff --git a/modules/l4socks/socks5_handler_test.go b/modules/l4socks/socks5_handler_test.go index e6170a5..a85f02e 100644 --- a/modules/l4socks/socks5_handler_test.go +++ b/modules/l4socks/socks5_handler_test.go @@ -17,7 +17,7 @@ import ( func replay(t *testing.T, handler *Socks5Handler, expectedError string, messages [][]byte) { t.Helper() in, out := net.Pipe() - cx := layer4.WrapConnection(out, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(out, []byte{}, zap.NewNop()) defer func() { _ = in.Close() _, _ = io.Copy(io.Discard, out) diff --git a/modules/l4socks/socks5_matcher_test.go b/modules/l4socks/socks5_matcher_test.go index b45b57e..dabf47e 100644 --- a/modules/l4socks/socks5_matcher_test.go +++ b/modules/l4socks/socks5_matcher_test.go @@ -1,7 +1,6 @@ package l4socks import ( - "bytes" "context" "io" "net" @@ -55,7 +54,7 @@ func TestSocks5Matcher_Match(t *testing.T) { _ = out.Close() }() - cx := layer4.WrapConnection(out, &bytes.Buffer{}, zap.NewNop()) + cx := layer4.WrapConnection(out, []byte{}, zap.NewNop()) go func() { _, err := in.Write(tc.data) assertNoError(t, err) diff --git a/modules/l4subroute/handler.go b/modules/l4subroute/handler.go index 4a671c6..4bb67d2 100644 --- a/modules/l4subroute/handler.go +++ b/modules/l4subroute/handler.go @@ -16,6 +16,7 @@ package l4subroute import ( "fmt" + "time" "github.com/caddyserver/caddy/v2" "github.com/mholt/caddy-l4/layer4" @@ -34,6 +35,9 @@ type Handler struct { // The primary list of routes to compile and execute. Routes layer4.RouteList `json:"routes,omitempty"` + // Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s. + MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"` + logger *zap.Logger } @@ -49,6 +53,10 @@ func (Handler) CaddyModule() caddy.ModuleInfo { func (h *Handler) Provision(ctx caddy.Context) error { h.logger = ctx.Logger(h) + if h.MatchingTimeout <= 0 { + h.MatchingTimeout = caddy.Duration(layer4.MatchingTimeoutDefault) + } + if h.Routes != nil { err := h.Routes.Provision(ctx) if err != nil { @@ -60,7 +68,10 @@ func (h *Handler) Provision(ctx caddy.Context) error { // Handle handles the connections. func (h *Handler) Handle(cx *layer4.Connection, next layer4.Handler) error { - subroute := h.Routes.Compile(next, h.logger) + subroute := h.Routes.Compile(h.logger, time.Duration(h.MatchingTimeout), + layer4.NextHandlerFunc(func(cx *layer4.Connection, _ layer4.Handler) error { + return next.Handle(cx) // continue with original chain after subroute + })) return subroute.Handle(cx) }