From e22fe90f76810acc5e2199048d2a55f64ca1747e Mon Sep 17 00:00:00 2001 From: Robert Kuo Date: Sun, 12 Jan 2025 17:40:57 +0800 Subject: [PATCH] router pattern --- router.go | 9 +- router_test.go | 277 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index f28b69d..e75dcb7 100644 --- a/router.go +++ b/router.go @@ -56,9 +56,6 @@ func (r *Router) SubRouter(pattern string) *Router { if err != nil { panic(err) } - if len(pat.path) > 0 && !strings.HasSuffix(pat.path, "/") { - pat.path += "/" - } subRouter := &Router{ pattern: pat, @@ -101,7 +98,11 @@ func (r *Router) getHttpHandlerMap() *httpHandlerMap { func (r *Router) setupHttpHandlerMap(targetHttpHandlerMap *httpHandlerMap) { localHandlerMap := newHttpHandlerMap() - localHandlerMap.addPatternString(r.pattern.String()) + routerPatternString := r.pattern.String() + if !strings.HasSuffix(routerPatternString, "/") { + routerPatternString += "/" + } + localHandlerMap.addPatternString(routerPatternString) for _, h := range r.handlerRepository.getHandlers() { if !h.all && h.owner == r { diff --git a/router_test.go b/router_test.go index 33aa59a..0606ebb 100644 --- a/router_test.go +++ b/router_test.go @@ -379,3 +379,280 @@ func TestNestedRouter(t *testing.T) { }) } } + +func TestSubRouter(t *testing.T) { + type result struct { + status int + test bool + getall bool + getid bool + post bool + postid bool + testnotfound bool + notfound bool + text string + } + type testcase struct { + method string + url string + } + + testcases := map[testcase]result{ + {http.MethodGet, "/"}: { + status: http.StatusNotFound, + test: false, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: false, + notfound: true, + text: "notfound", + }, + {http.MethodGet, "/abc"}: { + status: http.StatusNotFound, + test: false, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: false, + notfound: true, + text: "notfound", + }, + {http.MethodGet, "/test"}: { + status: http.StatusOK, + test: true, + getall: true, + getid: false, + post: false, + postid: false, + testnotfound: false, + notfound: false, + text: "getall", + }, + {http.MethodGet, "/test/"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + {http.MethodGet, "/test/123"}: { + status: http.StatusOK, + test: true, + getall: false, + getid: true, + post: false, + postid: false, + testnotfound: false, + notfound: false, + text: "123", + }, + {http.MethodGet, "/test/123/"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + {http.MethodGet, "/test/123/456"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + {http.MethodPost, "/"}: { + status: http.StatusNotFound, + test: false, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: false, + notfound: true, + text: "notfound", + }, + {http.MethodPost, "/abc"}: { + status: http.StatusNotFound, + test: false, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: false, + notfound: true, + text: "notfound", + }, + {http.MethodPost, "/test"}: { + status: http.StatusOK, + test: true, + getall: false, + getid: false, + post: true, + postid: false, + testnotfound: false, + notfound: false, + text: "post", + }, + {http.MethodPost, "/test/"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + {http.MethodPost, "/test/123"}: { + status: http.StatusOK, + test: true, + getall: false, + getid: false, + post: false, + postid: true, + testnotfound: false, + notfound: false, + text: "post123", + }, + {http.MethodPost, "/test/123/"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + {http.MethodPost, "/test/123/456"}: { + status: http.StatusNotFound, + test: true, + getall: false, + getid: false, + post: false, + postid: false, + testnotfound: true, + notfound: false, + text: "testnotfound", + }, + } + + router := New() + + testRouter := router.SubRouter("/test") + testRouter.Use(func(c *Context) Response { + c.ResHeader().Set("test", "1") + return c.Next() + }) + + testRouter.Handle("GET /", func(c *Context) Response { + c.ResHeader().Set("getall", "1") + return textResponse{ + http.StatusOK, + "getall", + } + }) + + testRouter.Handle("GET /{id}", func(c *Context) Response { + c.ResHeader().Set("getid", "1") + return textResponse{ + http.StatusOK, + c.Req.PathValue("id"), + } + }) + + testRouter.Handle("POST /", func(c *Context) Response { + c.ResHeader().Set("post", "1") + return textResponse{ + http.StatusOK, + "post", + } + }) + + testRouter.Handle("POST /{id}", func(c *Context) Response { + c.ResHeader().Set("postid", "1") + return textResponse{ + http.StatusOK, + "post" + c.Req.PathValue("id"), + } + }) + + testRouter.Use(func(c *Context) Response { + c.ResHeader().Set("testnotfound", "1") + return textResponse{ + http.StatusNotFound, + "testnotfound", + } + }) + + router.Use(func(c *Context) Response { + c.ResHeader().Set("notfound", "1") + return textResponse{ + http.StatusNotFound, + "notfound", + } + }) + + mux := router.CreateServeMux() + + assertHeader := func(res *http.Response, key string, exist bool) { + if exist { + if got, want := res.Header.Get(key), "1"; got != want { + t.Errorf("header %s error. got:\"%v\" want:\"%v\"", key, got, want) + } + } else { + if got, want := res.Header.Get(key), ""; got != want { + t.Errorf("header %s error. got:\"%v\" want:\"%v\"", key, got, want) + } + } + } + + for testcase, result := range testcases { + name := fmt.Sprintf("Test SubRouter: %s %s", testcase.method, testcase.url) + t.Run(name, func(t *testing.T) { + req := httptest.NewRequest(testcase.method, testcase.url, nil) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + res := w.Result() + defer res.Body.Close() + + data, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("unexpected error %v", err) + } + + if got, want := res.StatusCode, result.status; got != want { + t.Errorf("status error. got:\"%v\" want:\"%v\"", got, want) + } + if got, want := string(data), result.text; got != want { + t.Errorf("text error. got:\"%v\" want:\"%v\"", got, want) + } + assertHeader(res, "test", result.test) + assertHeader(res, "getall", result.getall) + assertHeader(res, "getid", result.getid) + assertHeader(res, "post", result.post) + assertHeader(res, "postid", result.postid) + assertHeader(res, "testnotfound", result.testnotfound) + assertHeader(res, "notfound", result.notfound) + }) + } +}