Skip to content

Refresh and Access token resolve handler #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"net/http"
"strings"
"time"

"github.com/go-oauth2/oauth2/v4"
Expand Down Expand Up @@ -49,8 +50,14 @@ type (
// ExtensionFieldsHandler in response to the access token with the extension of the field
ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})

// ResponseTokenHandler response token handing
// ResponseTokenHandler response token handling
ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error

// Handler to fetch the refresh token from the request
RefreshTokenResolveHandler func(r *http.Request) (string, error)

// Handler to fetch the access token from the request
AccessTokenResolveHandler func(r *http.Request) (string, bool)
)

// ClientFormHandler get client data from form
Expand All @@ -71,3 +78,44 @@ func ClientBasicHandler(r *http.Request) (string, string, error) {
}
return username, password, nil
}

func RefreshTokenFormResolveHandler(r *http.Request) (string, error) {
rt := r.FormValue("refresh_token")
if rt == "" {
return "", errors.ErrInvalidRequest
}

return rt, nil
}

func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) {
c, err := r.Cookie("refresh_token")
if err != nil {
return "", errors.ErrInvalidRequest
}

return c.Value, nil
}

func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) {
token := ""
auth := r.Header.Get("Authorization")
prefix := "Bearer "

if auth != "" && strings.HasPrefix(auth, prefix) {
token = auth[len(prefix):]
} else {
token = r.FormValue("access_token")
}

return token, token != ""
}

func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) {
c, err := r.Cookie("access_token")
if err != nil {
return "", false
}

return c.Value, true
}
121 changes: 121 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package server

import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/go-oauth2/oauth2/v4/errors"
. "github.com/smartystreets/goconvey/convey"
)

func TestRefreshTokenFormResolveHandler(t *testing.T) {
Convey("Correct Request", t, func() {
f := url.Values{}
f.Add("refresh_token", "test_token")

r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

token, err := RefreshTokenFormResolveHandler(r)
So(err, ShouldBeNil)
So(token, ShouldEqual, "test_token")
})

Convey("Missing Refresh Token", t, func() {
r := httptest.NewRequest("POST", "/", nil)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

token, err := RefreshTokenFormResolveHandler(r)
So(err, ShouldBeError, errors.ErrInvalidRequest)
So(token, ShouldBeEmpty)
})
}

func TestRefreshTokenCookieResolveHandler(t *testing.T) {
Convey("Correct Request", t, func() {
r := httptest.NewRequest(http.MethodPost, "/", nil)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.AddCookie(&http.Cookie{
Name: "refresh_token",
Value: "test_token",
HttpOnly: true,
Path: "/",
Domain: ".example.com",
Expires: time.Now().Add(time.Hour),
})

token, err := RefreshTokenCookieResolveHandler(r)
So(err, ShouldBeNil)
So(token, ShouldEqual, "test_token")
})

Convey("Missing Refresh Token", t, func() {
r := httptest.NewRequest("POST", "/", nil)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

token, err := RefreshTokenCookieResolveHandler(r)
So(err, ShouldBeError, errors.ErrInvalidRequest)
So(token, ShouldBeEmpty)
})
}

func TestAccessTokenDefaultHandler(t *testing.T) {
Convey("Request Has Header", t, func() {
r := httptest.NewRequest(http.MethodPost, "/", nil)
r.Header.Add("Authorization", "Bearer test_token")

token, ok := AccessTokenDefaultResolveHandler(r)
So(ok, ShouldBeTrue)
So(token, ShouldEqual, "test_token")
})

Convey("Request Has FormValue", t, func() {
f := url.Values{}
f.Add("access_token", "test_token")
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

token, ok := AccessTokenDefaultResolveHandler(r)
So(ok, ShouldBeTrue)
So(token, ShouldEqual, "test_token")
})

Convey("Request Has Nothing", t, func() {
r := httptest.NewRequest(http.MethodPost, "/", nil)

token, ok := AccessTokenDefaultResolveHandler(r)
So(ok, ShouldBeFalse)
So(token, ShouldBeEmpty)
})
}

func TestAccessTokenCookieHandler(t *testing.T) {
Convey("Request Has Cookie", t, func() {
r := httptest.NewRequest(http.MethodPost, "/", nil)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.AddCookie(&http.Cookie{
Name: "access_token",
Value: "test_token",
HttpOnly: true,
Path: "/",
Domain: ".example.com",
Expires: time.Now().Add(time.Hour),
})

token, ok := AccessTokenCookieResolveHandler(r)
So(ok, ShouldBeTrue)
So(token, ShouldEqual, "test_token")
})

Convey("Request Has No Cookie", t, func() {
r := httptest.NewRequest(http.MethodPost, "/", nil)

token, ok := AccessTokenCookieResolveHandler(r)
So(ok, ShouldBeFalse)
So(token, ShouldBeEmpty)
})
}
30 changes: 9 additions & 21 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/go-oauth2/oauth2/v4"
Expand All @@ -25,8 +24,10 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
Manager: manager,
}

// default handler
// default handlers
srv.ClientInfoHandler = ClientBasicHandler
srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler
srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler

srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
return "", errors.ErrAccessDenied
Expand Down Expand Up @@ -56,6 +57,8 @@ type Server struct {
AccessTokenExpHandler AccessTokenExpHandler
AuthorizeScopeHandler AuthorizeScopeHandler
ResponseTokenHandler ResponseTokenHandler
RefreshTokenResolveHandler RefreshTokenResolveHandler
AccessTokenResolveHandler AccessTokenResolveHandler
}

func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
Expand Down Expand Up @@ -367,10 +370,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
case oauth2.ClientCredentials:
tgr.Scope = r.FormValue("scope")
case oauth2.Refreshing:
tgr.Refresh = r.FormValue("refresh_token")
tgr.Refresh, err = s.RefreshTokenResolveHandler(r)
tgr.Scope = r.FormValue("scope")
if tgr.Refresh == "" {
return "", nil, errors.ErrInvalidRequest
if err != nil {
return "", nil, err
}
}
return gt, tgr, nil
Expand Down Expand Up @@ -569,27 +572,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head
return data, statusCode, re.Header
}

// BearerAuth parse bearer token
func (s *Server) BearerAuth(r *http.Request) (string, bool) {
auth := r.Header.Get("Authorization")
prefix := "Bearer "
token := ""

if auth != "" && strings.HasPrefix(auth, prefix) {
token = auth[len(prefix):]
} else {
token = r.FormValue("access_token")
}

return token, token != ""
}

// ValidationBearerToken validation the bearer tokens
// https://tools.ietf.org/html/rfc6750
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
ctx := r.Context()

accessToken, ok := s.BearerAuth(r)
accessToken, ok := s.AccessTokenResolveHandler(r)
if !ok {
return nil, errors.ErrInvalidAccessToken
}
Expand Down
10 changes: 10 additions & 0 deletions server/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,13 @@ func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) {
func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) {
s.ResponseTokenHandler = handler
}

// SetRefreshTokenResolveHandler refresh token resolver
func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) {
s.RefreshTokenResolveHandler = handler
}

// SetAccessTokenResolveHandler access token resolver
func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) {
s.AccessTokenResolveHandler = handler
}