diff --git a/layer4/connection.go b/layer4/connection.go index f193a54..cff1795 100644 --- a/layer4/connection.go +++ b/layer4/connection.go @@ -136,19 +136,19 @@ func (cx *Connection) Wrap(conn net.Conn) *Connection { // prefetch tries to read all bytes that a client initially sent us without blocking. func (cx *Connection) prefetch() (err error) { var n int - var tmp []byte - for len(cx.buf) < MaxMatchingBytes { + // read once + if len(cx.buf) < MaxMatchingBytes { free := cap(cx.buf) - len(cx.buf) if free >= prefetchChunkSize { n, err = cx.Conn.Read(cx.buf[len(cx.buf) : len(cx.buf)+prefetchChunkSize]) cx.buf = cx.buf[:len(cx.buf)+n] } else { - if tmp == nil { - tmp = bufPool.Get().([]byte) - tmp = tmp[:prefetchChunkSize] - defer bufPool.Put(tmp) - } + var tmp []byte + tmp = bufPool.Get().([]byte) + tmp = tmp[:prefetchChunkSize] + defer bufPool.Put(tmp) + n, err = cx.Conn.Read(tmp) cx.buf = append(cx.buf, tmp[:n]...) } @@ -159,23 +159,17 @@ func (cx *Connection) prefetch() (err error) { return err } - if n < prefetchChunkSize { - break + if cx.Logger.Core().Enabled(zap.DebugLevel) { + cx.Logger.Debug("prefetched", + zap.String("remote", cx.RemoteAddr().String()), + zap.Int("bytes", len(cx.buf)), + ) } - } - if cx.Logger.Core().Enabled(zap.DebugLevel) { - cx.Logger.Debug("prefetched", - zap.String("remote", cx.RemoteAddr().String()), - zap.Int("bytes", len(cx.buf)), - ) - } - - if len(cx.buf) >= MaxMatchingBytes { - return ErrMatchingBufferFull + return nil } - return nil + return ErrMatchingBufferFull } // freeze activates the matching mode that only reads from cx.buf. @@ -215,6 +209,8 @@ func (cx *Connection) GetVar(key string) interface{} { // MatchingBytes returns all bytes currently available for matching. This is only intended for reading. // Do not write into the slice. It's a view of the internal buffer, and you will likely mess up the connection. +// Use of this for matching purpose should be accompanied by corresponding error value, +// ErrConsumedAllPrefetchedBytes and ErrMatchingBufferFull, if not matched. func (cx *Connection) MatchingBytes() []byte { return cx.buf[cx.offset:] } diff --git a/layer4/handlers.go b/layer4/handlers.go index d41e5dd..aa13243 100644 --- a/layer4/handlers.go +++ b/layer4/handlers.go @@ -80,13 +80,16 @@ type nopHandler struct{} func (nopHandler) Handle(_ *Connection) error { return nil } -type nopNextHandler struct{} +// forwardNextHandler will forward the handling to the next handler in the chain. +type forwardNextHandler struct{} -func (nopNextHandler) Handle(cx *Connection, next Handler) error { return next.Handle(cx) } +func (forwardNextHandler) Handle(cx *Connection, next Handler) error { + return next.Handle(cx) +} // listenerHandler is a connection handler that pipe incoming connection to channel as a listener wrapper type listenerHandler struct{} -func (listenerHandler) Handle(conn *Connection, _ Handler) error { +func (listenerHandler) Handle(conn *Connection) error { return conn.Context.Value(listenerCtxKey).(*listener).pipeConnection(conn) } diff --git a/layer4/routes.go b/layer4/routes.go index bba6f0d..9d3ff75 100644 --- a/layer4/routes.go +++ b/layer4/routes.go @@ -73,7 +73,6 @@ func (r *Route) Provision(ctx caddy.Context) error { for _, midhandler := range handlers { r.middleware = append(r.middleware, wrapHandler(midhandler)) } - return nil } @@ -95,26 +94,40 @@ func (routes RouteList) Provision(ctx caddy.Context) error { return nil } +const ( + // routes that need more data to determine the match + routeNeedsMore = iota + // routes definitely not matched + routeNotMatched + routeMatched +) + // Compile prepares a middleware chain from the route list. // This should only be done once: after all the routes have // been provisioned, and before the server loop begins. -func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duration, next NextHandler) Handler { +func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duration, next Handler) Handler { return HandlerFunc(func(cx *Connection) error { deadline := time.Now().Add(matchingTimeout) - lastMatchedRouteIdx := -1 - router: + + var ( + lastMatchedRouteIdx = -1 + lastNeedsMoreIdx = -1 + routesStatus = make(map[int]int) + matcherNeedMore bool + ) + // this loop should only be done if there are matchers that can't determine the match, + // i.e. some of the matchers returned false, ErrConsumedAllPrefetchedBytes. The index which + // the loop begins depends upon if there is a matched route. + loop: // timeout matching to protect against malicious or very slow clients err := cx.Conn.SetReadDeadline(deadline) if err != nil { return err } - - for i := 0; i < 10000; i++ { // retry prefetching and matching routes until timeout - - // Do not call prefetch if this is the first loop iteration and there already is some data available, - // since this means we are at the start of a subroute handler and previous prefetch calls likely already fetched all bytes available from the client. - // Which means it would block the subroute handler. In the second iteration (if no subroute routes match) blocking is the correct behaviour. - if i != 0 || cx.buf == nil || len(cx.buf[cx.offset:]) == 0 { + for { + // only read more because matchers require more (no matcher in the simplest case). + // can happen if this routes list is embedded in another + if matcherNeedMore { err = cx.prefetch() if err != nil { logFunc := logger.Error @@ -128,18 +141,22 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio } for i, route := range routes { - // After a match continue with the routes after the matched one, instead of starting at the beginning. - // This is done for backwards compatibility with configs written before the "Non blocking matchers & matching timeout" rewrite. - // See https://github.com/mholt/caddy-l4/pull/192 and https://github.com/mholt/caddy-l4/pull/192#issuecomment-2143681952. if i <= lastMatchedRouteIdx { continue } - // Only skip once after a match, so it behaves like we continued after the match. - lastMatchedRouteIdx = -1 + + // If the route is definitely not matched, skip it + if s, ok := routesStatus[i]; ok && s == routeNotMatched && i <= lastNeedsMoreIdx { + continue + } + // now the matcher is after a matched route and current route needs more data to determine if more data is needed. + // note a matcher is skipped if the one after it can determine it is matched // A route must match at least one of the matcher sets matched, err := route.matcherSets.AnyMatch(cx) if errors.Is(err, ErrConsumedAllPrefetchedBytes) { + lastNeedsMoreIdx = i + routesStatus[i] = routeNeedsMore continue // ignore and try next route } if err != nil { @@ -147,6 +164,9 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio return nil } if matched { + routesStatus[i] = routeMatched + lastMatchedRouteIdx = i + lastNeedsMoreIdx = i // remove deadline after we matched err = cx.Conn.SetReadDeadline(time.Time{}) if err != nil { @@ -163,7 +183,7 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio return nil }) // compile the route handler stack with lastHandler being called last - handler := wrapHandler(next)(lastHandler) + handler := wrapHandler(forwardNextHandler{})(lastHandler) for i := len(route.middleware) - 1; i >= 0; i-- { handler = route.middleware[i](handler) } @@ -173,18 +193,31 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio } // If handler is terminal we stop routing, - // otherwise we jump back to the start of the routing loop to peel of more protocol layers. + // otherwise we try the next handler. if isTerminal { return nil - } else { - lastMatchedRouteIdx = i - goto router } + } else { + routesStatus[i] = routeNotMatched } } + // end of match + if lastMatchedRouteIdx == len(routes)-1 { + // next is called because if the last handler is terminal, it's already returned + return next.Handle(cx) + } + var indetermined int + for i, s := range routesStatus { + if i > lastMatchedRouteIdx && s == routeNeedsMore { + indetermined++ + } + } + // some of the matchers can't reach a conclusion + if indetermined > 0 { + matcherNeedMore = true + goto loop + } + return next.Handle(cx) } - - logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(errors.New("number of prefetch calls exhausted"))) - return nil }) } diff --git a/layer4/routes_test.go b/layer4/routes_test.go index 9024756..a66558c 100644 --- a/layer4/routes_test.go +++ b/layer4/routes_test.go @@ -2,7 +2,10 @@ package layer4 import ( "context" + "encoding/json" "errors" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "io" "net" "testing" "time" @@ -13,11 +16,33 @@ import ( "go.uber.org/zap/zaptest/observer" ) +type testIoMatcher struct { +} + +func (testIoMatcher) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.matchers.testIoMatcher", + New: func() caddy.Module { return new(testIoMatcher) }, + } +} + +func (m *testIoMatcher) Match(cx *Connection) (bool, error) { + buf := make([]byte, 1) + n, err := io.ReadFull(cx, buf) + return n > 0, err +} + func TestMatchingTimeoutWorks(t *testing.T) { ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) defer cancel() - routes := RouteList{&Route{}} + caddy.RegisterModule(testIoMatcher{}) + + routes := RouteList{&Route{ + MatcherSetsRaw: caddyhttp.RawMatcherSets{ + caddy.ModuleMap{"testIoMatcher": json.RawMessage("{}")}, // any io using matcher + }, + }} err := routes.Provision(ctx) if err != nil { @@ -27,9 +52,9 @@ func TestMatchingTimeoutWorks(t *testing.T) { matched := false loggerCore, logs := observer.New(zapcore.WarnLevel) compiledRoutes := routes.Compile(zap.New(loggerCore), 5*time.Millisecond, - NextHandlerFunc(func(con *Connection, next Handler) error { + HandlerFunc(func(con *Connection) error { matched = true - return next.Handle(con) + return nil })) in, out := net.Pipe() diff --git a/layer4/server.go b/layer4/server.go index 092f501..f599ca9 100644 --- a/layer4/server.go +++ b/layer4/server.go @@ -68,7 +68,7 @@ func (s *Server) Provision(ctx caddy.Context, logger *zap.Logger) error { if err != nil { return err } - s.compiledRoute = s.Routes.Compile(s.logger, time.Duration(s.MatchingTimeout), nopNextHandler{}) + s.compiledRoute = s.Routes.Compile(s.logger, time.Duration(s.MatchingTimeout), nopHandler{}) return nil } diff --git a/modules/l4http/httpmatcher.go b/modules/l4http/httpmatcher.go index ffbd54c..9935755 100644 --- a/modules/l4http/httpmatcher.go +++ b/modules/l4http/httpmatcher.go @@ -84,7 +84,14 @@ func (m *MatchHTTP) Match(cx *layer4.Connection) (bool, error) { var err error data := cx.MatchingBytes() - if !m.isHttp(data) { + needMore, matched := m.isHttp(data) + if needMore { + if len(data) >= layer4.MaxMatchingBytes { + return false, layer4.ErrMatchingBufferFull + } + return false, layer4.ErrConsumedAllPrefetchedBytes + } + if !matched { return false, nil } @@ -124,11 +131,13 @@ func (m *MatchHTTP) Match(cx *layer4.Connection) (bool, error) { return m.matcherSets.AnyMatch(req), nil } -func (m *MatchHTTP) isHttp(data []byte) bool { +// isHttp test if the buffered data looks like HTTP by looking at the first line. +// first boolean determines if more data is required +func (m MatchHTTP) isHttp(data []byte) (bool, bool) { // try to find the end of a http request line, for example " HTTP/1.1\r\n" i := bytes.IndexByte(data, 0x0a) // find first new line if i < 10 { - return false + return true, false } // assume only \n line ending start := i - 9 // position of space in front of HTTP @@ -138,7 +147,7 @@ func (m *MatchHTTP) isHttp(data []byte) bool { start -= 1 end -= 1 } - return bytes.Compare(data[start:end], []byte(" HTTP/")) == 0 + return false, bytes.Compare(data[start:end], []byte(" HTTP/")) == 0 } // Parses information from a http2 request with prior knowledge (RFC 7540 Section 3.4) diff --git a/modules/l4http/httpmatcher_test.go b/modules/l4http/httpmatcher_test.go index 7028436..ce63d3d 100644 --- a/modules/l4http/httpmatcher_test.go +++ b/modules/l4http/httpmatcher_test.go @@ -23,6 +23,28 @@ func assertNoError(t *testing.T, err error) { } } +// testHandler is a connection handler that will set a variable to let us know it was called. +type testHandler struct { +} + +// CaddyModule returns the Caddy module information. +func (testHandler) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.handlers.test_handler", + New: func() caddy.Module { return new(testHandler) }, + } +} + +// Handle handles the connections. +func (h *testHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { + cx.SetVar("test_handler_called", true) + return next.Handle(cx) +} + +func init() { + caddy.RegisterModule(testHandler{}) +} + func httpMatchTester(t *testing.T, matchers json.RawMessage, data []byte) (bool, error) { in, out := net.Pipe() defer func() { _ = in.Close() }() @@ -41,14 +63,15 @@ func httpMatchTester(t *testing.T, matchers json.RawMessage, data []byte) (bool, MatcherSetsRaw: caddyhttp.RawMatcherSets{ caddy.ModuleMap{"http": matchers}, }, + HandlersRaw: []json.RawMessage{json.RawMessage("{\"handler\":\"test_handler\"}")}, }} err := routes.Provision(ctx) assertNoError(t, err) matched := false compiledRoute := routes.Compile(zap.NewNop(), 10*time.Millisecond, - layer4.NextHandlerFunc(func(con *layer4.Connection, _ layer4.Handler) error { - matched = true + layer4.HandlerFunc(func(con *layer4.Connection) error { + matched = con.GetVar("test_handler_called") != nil return nil })) @@ -206,7 +229,7 @@ func TestHttpMatchingByProtocolWithHttps(t *testing.T) { handlerCalled := false compiledRoute := routes.Compile(zap.NewNop(), 100*time.Millisecond, - layer4.NextHandlerFunc(func(con *layer4.Connection, _ layer4.Handler) error { + layer4.HandlerFunc(func(con *layer4.Connection) error { handlerCalled = true return nil })) @@ -291,7 +314,7 @@ func TestMatchHTTP_isHttp(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - matched := (&MatchHTTP{}).isHttp(tc.data) + _, matched := MatchHTTP{}.isHttp(tc.data) if matched != tc.shouldMatch { t.Fatalf("test %v | matched: %v != shouldMatch: %v", tc.name, matched, tc.shouldMatch) } diff --git a/modules/l4ssh/matcher.go b/modules/l4ssh/matcher.go index 3fb5e17..451bfec 100644 --- a/modules/l4ssh/matcher.go +++ b/modules/l4ssh/matcher.go @@ -42,9 +42,9 @@ func (*MatchSSH) CaddyModule() caddy.ModuleInfo { // Match returns true if the connection looks like SSH. func (m *MatchSSH) Match(cx *layer4.Connection) (bool, error) { p := make([]byte, len(sshPrefix)) - n, err := io.ReadFull(cx, p) - if err != nil || n < len(sshPrefix) { - return false, nil + _, err := io.ReadFull(cx, p) + if err != nil { + return false, err } return bytes.Equal(p, sshPrefix), nil } diff --git a/modules/l4subroute/handler.go b/modules/l4subroute/handler.go index b36c5e2..643792b 100644 --- a/modules/l4subroute/handler.go +++ b/modules/l4subroute/handler.go @@ -70,10 +70,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { // Handle handles the connections. func (h *Handler) Handle(cx *layer4.Connection, next layer4.Handler) error { - subroute := h.Routes.Compile(h.logger, time.Duration(h.MatchingTimeout), - layer4.NextHandlerFunc(func(cx *layer4.Connection, _ layer4.Handler) error { - return next.Handle(cx) // continue with original chain after subroute - })) + subroute := h.Routes.Compile(h.logger, time.Duration(h.MatchingTimeout), next) return subroute.Handle(cx) } diff --git a/modules/l4xmpp/matcher.go b/modules/l4xmpp/matcher.go index e930f24..ec14ab3 100644 --- a/modules/l4xmpp/matcher.go +++ b/modules/l4xmpp/matcher.go @@ -42,9 +42,9 @@ func (*MatchXMPP) CaddyModule() caddy.ModuleInfo { // Match returns true if the connection looks like XMPP. func (m *MatchXMPP) Match(cx *layer4.Connection) (bool, error) { p := make([]byte, minXmppLength) - n, err := io.ReadFull(cx, p) - if err != nil || n < minXmppLength { // needs at least 50 (fix for adium/pidgin) - return false, nil + _, err := io.ReadFull(cx, p) + if err != nil { // needs at least 50 (fix for adium/pidgin) + return false, err } return strings.Contains(string(p), xmppWord), nil }