Skip to content

Commit

Permalink
fix: router pattern (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
a179346 authored Jan 12, 2025
1 parent 02de2bc commit 46cf117
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 4 deletions.
9 changes: 5 additions & 4 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ func (r *Router) SubRouter(pattern string) *Router {
if err != nil {
panic(err)
}
if len(pat.path) > 0 && !strings.HasSuffix(pat.path, "/") {
pat.path += "/"
}

subRouter := &Router{
pattern: pat,
Expand Down Expand Up @@ -101,7 +98,11 @@ func (r *Router) getHttpHandlerMap() *httpHandlerMap {

func (r *Router) setupHttpHandlerMap(targetHttpHandlerMap *httpHandlerMap) {
localHandlerMap := newHttpHandlerMap()
localHandlerMap.addPatternString(r.pattern.String())
routerPatternString := r.pattern.String()
if !strings.HasSuffix(routerPatternString, "/") {
routerPatternString += "/"
}
localHandlerMap.addPatternString(routerPatternString)

for _, h := range r.handlerRepository.getHandlers() {
if !h.all && h.owner == r {
Expand Down
277 changes: 277 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,280 @@ func TestNestedRouter(t *testing.T) {
})
}
}

func TestSubRouter(t *testing.T) {
type result struct {
status int
test bool
getall bool
getid bool
post bool
postid bool
testnotfound bool
notfound bool
text string
}
type testcase struct {
method string
url string
}

testcases := map[testcase]result{
{http.MethodGet, "/"}: {
status: http.StatusNotFound,
test: false,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: false,
notfound: true,
text: "notfound",
},
{http.MethodGet, "/abc"}: {
status: http.StatusNotFound,
test: false,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: false,
notfound: true,
text: "notfound",
},
{http.MethodGet, "/test"}: {
status: http.StatusOK,
test: true,
getall: true,
getid: false,
post: false,
postid: false,
testnotfound: false,
notfound: false,
text: "getall",
},
{http.MethodGet, "/test/"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
{http.MethodGet, "/test/123"}: {
status: http.StatusOK,
test: true,
getall: false,
getid: true,
post: false,
postid: false,
testnotfound: false,
notfound: false,
text: "123",
},
{http.MethodGet, "/test/123/"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
{http.MethodGet, "/test/123/456"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
{http.MethodPost, "/"}: {
status: http.StatusNotFound,
test: false,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: false,
notfound: true,
text: "notfound",
},
{http.MethodPost, "/abc"}: {
status: http.StatusNotFound,
test: false,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: false,
notfound: true,
text: "notfound",
},
{http.MethodPost, "/test"}: {
status: http.StatusOK,
test: true,
getall: false,
getid: false,
post: true,
postid: false,
testnotfound: false,
notfound: false,
text: "post",
},
{http.MethodPost, "/test/"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
{http.MethodPost, "/test/123"}: {
status: http.StatusOK,
test: true,
getall: false,
getid: false,
post: false,
postid: true,
testnotfound: false,
notfound: false,
text: "post123",
},
{http.MethodPost, "/test/123/"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
{http.MethodPost, "/test/123/456"}: {
status: http.StatusNotFound,
test: true,
getall: false,
getid: false,
post: false,
postid: false,
testnotfound: true,
notfound: false,
text: "testnotfound",
},
}

router := New()

testRouter := router.SubRouter("/test")
testRouter.Use(func(c *Context) Response {
c.ResHeader().Set("test", "1")
return c.Next()
})

testRouter.Handle("GET /", func(c *Context) Response {
c.ResHeader().Set("getall", "1")
return textResponse{
http.StatusOK,
"getall",
}
})

testRouter.Handle("GET /{id}", func(c *Context) Response {
c.ResHeader().Set("getid", "1")
return textResponse{
http.StatusOK,
c.Req.PathValue("id"),
}
})

testRouter.Handle("POST /", func(c *Context) Response {
c.ResHeader().Set("post", "1")
return textResponse{
http.StatusOK,
"post",
}
})

testRouter.Handle("POST /{id}", func(c *Context) Response {
c.ResHeader().Set("postid", "1")
return textResponse{
http.StatusOK,
"post" + c.Req.PathValue("id"),
}
})

testRouter.Use(func(c *Context) Response {
c.ResHeader().Set("testnotfound", "1")
return textResponse{
http.StatusNotFound,
"testnotfound",
}
})

router.Use(func(c *Context) Response {
c.ResHeader().Set("notfound", "1")
return textResponse{
http.StatusNotFound,
"notfound",
}
})

mux := router.CreateServeMux()

assertHeader := func(res *http.Response, key string, exist bool) {
if exist {
if got, want := res.Header.Get(key), "1"; got != want {
t.Errorf("header %s error. got:\"%v\" want:\"%v\"", key, got, want)
}
} else {
if got, want := res.Header.Get(key), ""; got != want {
t.Errorf("header %s error. got:\"%v\" want:\"%v\"", key, got, want)
}
}
}

for testcase, result := range testcases {
name := fmt.Sprintf("Test SubRouter: %s %s", testcase.method, testcase.url)
t.Run(name, func(t *testing.T) {
req := httptest.NewRequest(testcase.method, testcase.url, nil)
w := httptest.NewRecorder()

mux.ServeHTTP(w, req)

res := w.Result()
defer res.Body.Close()

data, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("unexpected error %v", err)
}

if got, want := res.StatusCode, result.status; got != want {
t.Errorf("status error. got:\"%v\" want:\"%v\"", got, want)
}
if got, want := string(data), result.text; got != want {
t.Errorf("text error. got:\"%v\" want:\"%v\"", got, want)
}
assertHeader(res, "test", result.test)
assertHeader(res, "getall", result.getall)
assertHeader(res, "getid", result.getid)
assertHeader(res, "post", result.post)
assertHeader(res, "postid", result.postid)
assertHeader(res, "testnotfound", result.testnotfound)
assertHeader(res, "notfound", result.notfound)
})
}
}

0 comments on commit 46cf117

Please sign in to comment.