diff --git a/server/handler.go b/server/handler.go index 9a0be8c..81d54a1 100755 --- a/server/handler.go +++ b/server/handler.go @@ -3,6 +3,7 @@ package server import ( "context" "net/http" + "strings" "time" "github.com/go-oauth2/oauth2/v4" @@ -49,8 +50,14 @@ type ( // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) - // ResponseTokenHandler response token handing + // ResponseTokenHandler response token handling ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error + + // Handler to fetch the refresh token from the request + RefreshTokenResolveHandler func(r *http.Request) (string, error) + + // Handler to fetch the access token from the request + AccessTokenResolveHandler func(r *http.Request) (string, bool) ) // ClientFormHandler get client data from form @@ -71,3 +78,44 @@ func ClientBasicHandler(r *http.Request) (string, string, error) { } return username, password, nil } + +func RefreshTokenFormResolveHandler(r *http.Request) (string, error) { + rt := r.FormValue("refresh_token") + if rt == "" { + return "", errors.ErrInvalidRequest + } + + return rt, nil +} + +func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) { + c, err := r.Cookie("refresh_token") + if err != nil { + return "", errors.ErrInvalidRequest + } + + return c.Value, nil +} + +func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) { + token := "" + auth := r.Header.Get("Authorization") + prefix := "Bearer " + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) { + c, err := r.Cookie("access_token") + if err != nil { + return "", false + } + + return c.Value, true +} diff --git a/server/handler_test.go b/server/handler_test.go new file mode 100644 index 0000000..dac0337 --- /dev/null +++ b/server/handler_test.go @@ -0,0 +1,121 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/go-oauth2/oauth2/v4/errors" + . "github.com/smartystreets/goconvey/convey" +) + +func TestRefreshTokenFormResolveHandler(t *testing.T) { + Convey("Correct Request", t, func() { + f := url.Values{} + f.Add("refresh_token", "test_token") + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenFormResolveHandler(r) + So(err, ShouldBeNil) + So(token, ShouldEqual, "test_token") + }) + + Convey("Missing Refresh Token", t, func() { + r := httptest.NewRequest("POST", "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenFormResolveHandler(r) + So(err, ShouldBeError, errors.ErrInvalidRequest) + So(token, ShouldBeEmpty) + }) +} + +func TestRefreshTokenCookieResolveHandler(t *testing.T) { + Convey("Correct Request", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.AddCookie(&http.Cookie{ + Name: "refresh_token", + Value: "test_token", + HttpOnly: true, + Path: "/", + Domain: ".example.com", + Expires: time.Now().Add(time.Hour), + }) + + token, err := RefreshTokenCookieResolveHandler(r) + So(err, ShouldBeNil) + So(token, ShouldEqual, "test_token") + }) + + Convey("Missing Refresh Token", t, func() { + r := httptest.NewRequest("POST", "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenCookieResolveHandler(r) + So(err, ShouldBeError, errors.ErrInvalidRequest) + So(token, ShouldBeEmpty) + }) +} + +func TestAccessTokenDefaultHandler(t *testing.T) { + Convey("Request Has Header", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Add("Authorization", "Bearer test_token") + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has FormValue", t, func() { + f := url.Values{} + f.Add("access_token", "test_token") + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has Nothing", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeFalse) + So(token, ShouldBeEmpty) + }) +} + +func TestAccessTokenCookieHandler(t *testing.T) { + Convey("Request Has Cookie", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.AddCookie(&http.Cookie{ + Name: "access_token", + Value: "test_token", + HttpOnly: true, + Path: "/", + Domain: ".example.com", + Expires: time.Now().Add(time.Hour), + }) + + token, ok := AccessTokenCookieResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has No Cookie", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + + token, ok := AccessTokenCookieResolveHandler(r) + So(ok, ShouldBeFalse) + So(token, ShouldBeEmpty) + }) +} diff --git a/server/server.go b/server/server.go index df19d1f..f4dba2d 100755 --- a/server/server.go +++ b/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "github.com/go-oauth2/oauth2/v4" @@ -25,8 +24,10 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { Manager: manager, } - // default handler + // default handlers srv.ClientInfoHandler = ClientBasicHandler + srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler + srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied @@ -56,6 +57,8 @@ type Server struct { AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler ResponseTokenHandler ResponseTokenHandler + RefreshTokenResolveHandler RefreshTokenResolveHandler + AccessTokenResolveHandler AccessTokenResolveHandler } func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { @@ -367,10 +370,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") + tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest + if err != nil { + return "", nil, err } } return gt, tgr, nil @@ -569,27 +572,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head return data, statusCode, re.Header } -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() - accessToken, ok := s.BearerAuth(r) + accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } diff --git a/server/server_config.go b/server/server_config.go index d4f0404..660da8d 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -93,3 +93,13 @@ func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) { s.ResponseTokenHandler = handler } + +// SetRefreshTokenResolveHandler refresh token resolver +func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) { + s.RefreshTokenResolveHandler = handler +} + +// SetAccessTokenResolveHandler access token resolver +func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) { + s.AccessTokenResolveHandler = handler +}