From 9b85d4722a8760f2e2aec5e16ba4cc2ddc9898bc Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Sun, 12 Jan 2025 21:32:40 -0500 Subject: [PATCH] fix: https://github.com/gorilla/mux/pull/517 --- mux_test.go | 157 ++++++++++++++++++++++++++++++++++++++++++++++++++++ route.go | 9 ++- 2 files changed, 163 insertions(+), 3 deletions(-) diff --git a/mux_test.go b/mux_test.go index bac758bc..1b01721a 100644 --- a/mux_test.go +++ b/mux_test.go @@ -3137,6 +3137,163 @@ func BenchmarkPopulateContext(b *testing.B) { } } +// testOptionsMiddleWare returns 200 on an OPTIONS request +func testOptionsMiddleWare(inner http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + inner.ServeHTTP(w, r) + }) +} + +// TestRouterOrder Should Pass whichever order route is defined +func TestRouterOrder(t *testing.T) { + type requestCase struct { + request *http.Request + expCode int + } + + tests := []struct { + name string + routes []*Route + customMiddleware MiddlewareFunc + requests []requestCase + }{ + { + name: "Routes added with same method and intersecting path regex", + routes: []*Route{ + new(Route).Path("/a/b").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })).Methods(http.MethodGet), + new(Route).Path("/a/{a}").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).Methods(http.MethodGet), + }, + requests: []requestCase{ + { + request: newRequest(http.MethodGet, "/a/b"), + expCode: http.StatusNotFound, + }, + { + request: newRequest(http.MethodGet, "/a/a"), + expCode: http.StatusOK, + }, + }, + }, + { + name: "Routes added with same method and intersecting path regex, path with pathVariable first", + routes: []*Route{ + new(Route).Path("/a/{a}").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).Methods(http.MethodGet), + new(Route).Path("/a/b").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })).Methods(http.MethodGet), + }, + requests: []requestCase{ + { + request: newRequest(http.MethodGet, "/a/b"), + expCode: http.StatusOK, + }, + { + request: newRequest(http.MethodGet, "/a/a"), + expCode: http.StatusOK, + }, + }, + }, + { + name: "Routes added same path - different methods, no path variables", + routes: []*Route{ + new(Route).Path("/a/b").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).Methods(http.MethodGet), + new(Route).Path("/a/b").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })).Methods(http.MethodOptions), + }, + requests: []requestCase{ + { + request: newRequest(http.MethodGet, "/a/b"), + expCode: http.StatusOK, + }, + { + request: newRequest(http.MethodOptions, "/a/b"), + expCode: http.StatusNotFound, + }, + }, + }, + { + name: "Routes added same path - different methods, with path variables and middleware", + routes: []*Route{ + new(Route).Path("/a/{a}").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })).Methods(http.MethodGet), + new(Route).Path("/a/b").Handler(nil).Methods(http.MethodOptions), + }, + customMiddleware: testOptionsMiddleWare, + requests: []requestCase{ + { + request: newRequest(http.MethodGet, "/a/b"), + expCode: http.StatusNotFound, + }, + { + request: newRequest(http.MethodOptions, "/a/b"), + expCode: http.StatusOK, + }, + }, + }, + { + name: "Routes added same path - different methods, with path variables and middleware order reversed", + routes: []*Route{ + new(Route).Path("/a/b").Handler(nil).Methods(http.MethodOptions), + new(Route).Path("/a/{a}").Handler(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })).Methods(http.MethodGet), + }, + customMiddleware: testOptionsMiddleWare, + requests: []requestCase{ + { + request: newRequest(http.MethodGet, "/a/b"), + expCode: http.StatusNotFound, + }, + { + request: newRequest(http.MethodOptions, "/a/b"), + expCode: http.StatusOK, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + router := NewRouter() + + if test.customMiddleware != nil { + router.Use(test.customMiddleware) + } + + router.routes = test.routes + w := NewRecorder() + + for _, requestCase := range test.requests { + router.ServeHTTP(w, requestCase.request) + + if w.Code != requestCase.expCode { + t.Fatalf("Expected status code %d (got %d)", requestCase.expCode, w.Code) + } + } + }) + } +} + // mapToPairs converts a string map to a slice of string pairs func mapToPairs(m map[string]string) []string { var i int diff --git a/route.go b/route.go index 4eba5343..cee0a44c 100644 --- a/route.go +++ b/route.go @@ -103,11 +103,14 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { return false } - if match.MatchErr != nil && r.handler != nil { + // If a route matches, but the HTTP method does not, we do one of two (2) things: + // 1. Reset the match error if we find a matching method later. + // 2. Else, we override the matched handler in the event we have a possible fallback handler for that route. + // + // This prevents propagation of ErrMethodMismatch once a suitable match is found for a Method-Path combination + if match.MatchErr == ErrMethodMismatch { // We found a route which matches request method, clear MatchErr match.MatchErr = nil - // Then override the mis-matched handler - match.Handler = r.handler } // Yay, we have a match. Let's collect some info about it.