From 462689916748c0a702f9feef7c4d2f00758a431b Mon Sep 17 00:00:00 2001 From: attilabanga Date: Fri, 21 Mar 2025 13:02:46 +0100 Subject: [PATCH 1/3] fix typo --- server/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/handler.go b/server/handler.go index 9a0be8c..ed77e70 100755 --- a/server/handler.go +++ b/server/handler.go @@ -49,7 +49,7 @@ type ( // ExtensionFieldsHandler in response to the access token with the extension of the field ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) - // ResponseTokenHandler response token handing + // ResponseTokenHandler response token handling ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error ) From c36e1f85c289c25be39be71af14074df97b17d32 Mon Sep 17 00:00:00 2001 From: attilabanga Date: Fri, 21 Mar 2025 14:34:19 +0100 Subject: [PATCH 2/3] get the refresh token from the request with a customizable handler function --- server/handler.go | 21 ++++++++++++++ server/handler_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++ server/server.go | 10 ++++--- 3 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 server/handler_test.go diff --git a/server/handler.go b/server/handler.go index ed77e70..a3ea93d 100755 --- a/server/handler.go +++ b/server/handler.go @@ -51,6 +51,9 @@ type ( // ResponseTokenHandler response token handling ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error + + // Handler to fetch the refresh token from the request + RefreshTokenResolveHandler func(r *http.Request) (string, error) ) // ClientFormHandler get client data from form @@ -71,3 +74,21 @@ func ClientBasicHandler(r *http.Request) (string, string, error) { } return username, password, nil } + +func RefreshTokenFormResolveHandler(r *http.Request) (string, error) { + rt := r.FormValue("refresh_token") + if rt == "" { + return "", errors.ErrInvalidRequest + } + + return rt, nil +} + +func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) { + c, err := r.Cookie("refresh_token") + if err != nil { + return "", errors.ErrInvalidRequest + } + + return c.Value, nil +} diff --git a/server/handler_test.go b/server/handler_test.go new file mode 100644 index 0000000..e771556 --- /dev/null +++ b/server/handler_test.go @@ -0,0 +1,64 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/go-oauth2/oauth2/v4/errors" + . "github.com/smartystreets/goconvey/convey" +) + +func TestRefreshTokenFormResolveHandler(t *testing.T) { + Convey("Correct Request", t, func() { + f := url.Values{} + f.Add("refresh_token", "test_token") + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenFormResolveHandler(r) + So(err, ShouldBeNil) + So(token, ShouldEqual, "test_token") + }) + + Convey("Missing Refresh Token", t, func() { + r := httptest.NewRequest("POST", "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenFormResolveHandler(r) + So(err, ShouldBeError, errors.ErrInvalidRequest) + So(token, ShouldBeEmpty) + }) +} + +func TestRefreshTokenCookieResolveHandler(t *testing.T) { + Convey("Correct Request", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.AddCookie(&http.Cookie{ + Name: "refresh_token", + Value: "test_token", + HttpOnly: true, + Path: "/", + Domain: ".example.com", + Expires: time.Now().Add(time.Hour), + }) + + token, err := RefreshTokenCookieResolveHandler(r) + So(err, ShouldBeNil) + So(token, ShouldEqual, "test_token") + }) + + Convey("Missing Refresh Token", t, func() { + r := httptest.NewRequest("POST", "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, err := RefreshTokenCookieResolveHandler(r) + So(err, ShouldBeError, errors.ErrInvalidRequest) + So(token, ShouldBeEmpty) + }) +} diff --git a/server/server.go b/server/server.go index df19d1f..64480fe 100755 --- a/server/server.go +++ b/server/server.go @@ -25,8 +25,9 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { Manager: manager, } - // default handler + // default handlers srv.ClientInfoHandler = ClientBasicHandler + srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied @@ -56,6 +57,7 @@ type Server struct { AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler ResponseTokenHandler ResponseTokenHandler + RefreshTokenResolveHandler RefreshTokenResolveHandler } func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { @@ -367,10 +369,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") + tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest + if err != nil { + return "", nil, err } } return gt, tgr, nil From d92fb72678e88db4702278347e047c29eb3e5c02 Mon Sep 17 00:00:00 2001 From: attilabanga Date: Wed, 26 Mar 2025 12:50:50 +0100 Subject: [PATCH 3/3] BearerAuth function changed to AccessTokenResolveHandler BearerAuth function changed to AccessTokenResolveHandler removed unused dep setter --- server/handler.go | 27 +++++++++++++++++++ server/handler_test.go | 57 +++++++++++++++++++++++++++++++++++++++++ server/server.go | 20 +++------------ server/server_config.go | 10 ++++++++ 4 files changed, 97 insertions(+), 17 deletions(-) diff --git a/server/handler.go b/server/handler.go index a3ea93d..81d54a1 100755 --- a/server/handler.go +++ b/server/handler.go @@ -3,6 +3,7 @@ package server import ( "context" "net/http" + "strings" "time" "github.com/go-oauth2/oauth2/v4" @@ -54,6 +55,9 @@ type ( // Handler to fetch the refresh token from the request RefreshTokenResolveHandler func(r *http.Request) (string, error) + + // Handler to fetch the access token from the request + AccessTokenResolveHandler func(r *http.Request) (string, bool) ) // ClientFormHandler get client data from form @@ -92,3 +96,26 @@ func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) { return c.Value, nil } + +func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) { + token := "" + auth := r.Header.Get("Authorization") + prefix := "Bearer " + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) { + c, err := r.Cookie("access_token") + if err != nil { + return "", false + } + + return c.Value, true +} diff --git a/server/handler_test.go b/server/handler_test.go index e771556..dac0337 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -62,3 +62,60 @@ func TestRefreshTokenCookieResolveHandler(t *testing.T) { So(token, ShouldBeEmpty) }) } + +func TestAccessTokenDefaultHandler(t *testing.T) { + Convey("Request Has Header", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Add("Authorization", "Bearer test_token") + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has FormValue", t, func() { + f := url.Values{} + f.Add("access_token", "test_token") + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has Nothing", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + + token, ok := AccessTokenDefaultResolveHandler(r) + So(ok, ShouldBeFalse) + So(token, ShouldBeEmpty) + }) +} + +func TestAccessTokenCookieHandler(t *testing.T) { + Convey("Request Has Cookie", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.AddCookie(&http.Cookie{ + Name: "access_token", + Value: "test_token", + HttpOnly: true, + Path: "/", + Domain: ".example.com", + Expires: time.Now().Add(time.Hour), + }) + + token, ok := AccessTokenCookieResolveHandler(r) + So(ok, ShouldBeTrue) + So(token, ShouldEqual, "test_token") + }) + + Convey("Request Has No Cookie", t, func() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + + token, ok := AccessTokenCookieResolveHandler(r) + So(ok, ShouldBeFalse) + So(token, ShouldBeEmpty) + }) +} diff --git a/server/server.go b/server/server.go index 64480fe..f4dba2d 100755 --- a/server/server.go +++ b/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "github.com/go-oauth2/oauth2/v4" @@ -28,6 +27,7 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { // default handlers srv.ClientInfoHandler = ClientBasicHandler srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler + srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied @@ -58,6 +58,7 @@ type Server struct { AuthorizeScopeHandler AuthorizeScopeHandler ResponseTokenHandler ResponseTokenHandler RefreshTokenResolveHandler RefreshTokenResolveHandler + AccessTokenResolveHandler AccessTokenResolveHandler } func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { @@ -571,27 +572,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head return data, statusCode, re.Header } -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() - accessToken, ok := s.BearerAuth(r) + accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } diff --git a/server/server_config.go b/server/server_config.go index d4f0404..660da8d 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -93,3 +93,13 @@ func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) { s.ResponseTokenHandler = handler } + +// SetRefreshTokenResolveHandler refresh token resolver +func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) { + s.RefreshTokenResolveHandler = handler +} + +// SetAccessTokenResolveHandler access token resolver +func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) { + s.AccessTokenResolveHandler = handler +}