Skip to content

Commit

Permalink
refactor: simplify handler and test logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Oct 18, 2024
1 parent dcaecc5 commit 24db6b9
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 66 deletions.
17 changes: 9 additions & 8 deletions device_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method))
}
if err := r.ParseForm(); err != nil {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
}
if len(r.PostForm) == 0 {
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty."))
Expand All @@ -44,11 +44,11 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
request.Client = client

if !client.GetGrantTypes().Has(string(GrantTypeDeviceCode)) {
return nil, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant."))
return request, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant."))
}

if err := f.validateDeviceScope(ctx, r, request); err != nil {
return nil, err
return request, err
}

if err := f.validateAudience(ctx, r, request); err != nil {
Expand All @@ -59,12 +59,13 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic
}

func (f *Fosite) validateDeviceScope(ctx context.Context, r *http.Request, request *DeviceRequest) error {
scope := RemoveEmpty(strings.Split(request.Form.Get("scope"), " "))
for _, permission := range scope {
if !f.Config.GetScopeStrategy(ctx)(request.Client.GetScopes(), permission) {
return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", permission))
scopes := RemoveEmpty(strings.Split(request.Form.Get("scope"), " "))
scopeStrategy := f.Config.GetScopeStrategy(ctx)
for _, scope := range scopes {
if !scopeStrategy(request.Client.GetScopes(), scope) {
return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope))
}
}
request.SetRequestedScopes(scope)
request.SetRequestedScopes(scopes)
return nil
}
88 changes: 48 additions & 40 deletions device_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,17 @@ import (
func TestNewDeviceRequestWithPublicClient(t *testing.T) {
ctrl := gomock.NewController(t)
store := internal.NewMockStorage(ctrl)
client := &DefaultClient{ID: "client_id"}
deviceClient := &DefaultClient{ID: "client_id"}
deviceClient.Public = true
deviceClient.Scopes = []string{"17", "42"}
deviceClient.Audience = []string{"aud2"}
deviceClient.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}

authCodeClient := &DefaultClient{ID: "client_id_2"}
authCodeClient.Public = true
authCodeClient.Scopes = []string{"17", "42"}
authCodeClient.GrantTypes = []string{"authorization_code"}

defer ctrl.Finish()
config := &Config{ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
fosite := &Fosite{Store: store, Config: config}
Expand Down Expand Up @@ -63,40 +73,30 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
expectedError: ErrInvalidScope,
}, {
description: "fails because audience not allowed",
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
"audience": {"aud"},
"audience": {"random_aud"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.Audience = []string{"aud2"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
expectedError: ErrInvalidRequest,
}, {
description: "fails because it doesn't have the proper grant",
form: url.Values{
"client_id": {"client_id"},
"client_id": {"client_id_2"},
"scope": {"17 42"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"authorization_code"}
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id_2")).Return(authCodeClient, nil)
},
expectedError: ErrInvalidGrant,
}, {
Expand All @@ -107,10 +107,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil)
},
}} {
t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) {
Expand All @@ -123,10 +120,8 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
}

ar, err := fosite.NewDeviceRequest(context.Background(), r)
if c.expectedError != nil {
assert.EqualError(t, err, c.expectedError.Error())
} else {
require.NoError(t, err)
require.ErrorIs(t, err, c.expectedError)
if c.expectedError == nil {
assert.NotNil(t, ar.GetRequestedAt())
}
})
Expand All @@ -141,15 +136,21 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
defer ctrl.Finish()
config := &Config{ClientSecretsHasher: hasher, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
fosite := &Fosite{Store: store, Config: config}

client.Public = false
client.Secret = []byte("client_secret")
client.Scopes = []string{"foo", "bar"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}

for k, c := range []struct {
header http.Header
form url.Values
method string
expectedError error
mock func()
expect DeviceRequester
description string
}{
// No client authn provided
{
form: url.Values{
"client_id": {"client_id"},
Expand All @@ -159,14 +160,26 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = false
client.Secret = []byte("client_secret")
client.Scopes = []string{"foo", "bar"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New(""))
},
description: "Should failed becaue no client authn provided.",
},
{
form: url.Values{
"client_id": {"client_id2"},
"scope": {"foo bar"},
},
header: http.Header{
"Authorization": {basicAuth("client_id", "client_secret")},
},
expectedError: ErrInvalidRequest,
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
},
description: "should fail because different client is used in authn than in form",
},
// success
{
form: url.Values{
"client_id": {"client_id"},
Expand All @@ -178,15 +191,12 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = false
client.Secret = []byte("client_secret")
client.Scopes = []string{"foo", "bar"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil)
},
description: "should succeed",
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) {
c.mock()
r := &http.Request{
Header: c.header,
Expand All @@ -196,11 +206,9 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) {
}

req, err := fosite.NewDeviceRequest(context.Background(), r)
if c.expectedError != nil {
assert.EqualError(t, err, c.expectedError.Error())
} else {
require.NoError(t, err)
assert.NotNil(t, req.GetRequestedAt())
require.ErrorIs(t, err, c.expectedError)
if c.expectedError == nil {
assert.NotZero(t, req.GetRequestedAt())
}
})
}
Expand Down
18 changes: 0 additions & 18 deletions device_request_test.go

This file was deleted.

11 changes: 11 additions & 0 deletions fosite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
. "github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/par"
"github.com/ory/fosite/handler/rfc8628"
)

func TestAuthorizeEndpointHandlers(t *testing.T) {
Expand All @@ -25,6 +26,16 @@ func TestAuthorizeEndpointHandlers(t *testing.T) {
assert.Equal(t, hs[0], h)
}

func TestDeviceAuthorizeEndpointHandlers(t *testing.T) {
h := &rfc8628.DeviceAuthHandler{}
hs := DeviceEndpointHandlers{}
hs.Append(h)
hs.Append(h)
hs.Append(&rfc8628.DeviceAuthHandler{})
assert.Len(t, hs, 1)
assert.Equal(t, hs[0], h)
}

func TestTokenEndpointHandlers(t *testing.T) {
h := &oauth2.AuthorizeExplicitGrantHandler{}
hs := TokenEndpointHandlers{}
Expand Down

0 comments on commit 24db6b9

Please sign in to comment.