Skip to content

Commit

Permalink
feat: support more claims in password grant
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Oct 28, 2024
1 parent 9cc5f28 commit efa5d5a
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 27 deletions.
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ require (
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/oleiade/reflections v1.0.1
github.com/ory/analytics-go/v5 v5.0.1
github.com/ory/fosite v0.47.0
github.com/ory/fosite v0.47.1-0.20241028132122-b35b62fed16e
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe
github.com/ory/graceful v0.1.3
github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88
Expand Down Expand Up @@ -69,8 +69,6 @@ require (
golang.org/x/tools v0.23.0
)

require github.com/hashicorp/go-cleanhttp v0.5.2 // indirect

require (
code.dny.dev/ssrf v0.2.0 // indirect
dario.cat/mergo v1.0.0 // indirect
Expand Down Expand Up @@ -147,6 +145,7 @@ require (
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/huandu/xstrings v1.4.0 // indirect
github.com/imdario/mergo v0.3.16 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,8 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp
github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0=
github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E=
github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk=
github.com/ory/fosite v0.47.0 h1:Iqu5uhx54JqZQPn2hRhqjESrmRRyQb00uJjfEi1a1QI=
github.com/ory/fosite v0.47.0/go.mod h1:5U6c9nOLxyTdD/qrFv7N88TSxkdk5Wq8NzvB7UViDP0=
github.com/ory/fosite v0.47.1-0.20241028132122-b35b62fed16e h1:aRhPoQt0QFrjQVrNDGiGQcT9cY/o8iqqBkMdoiyznjI=
github.com/ory/fosite v0.47.1-0.20241028132122-b35b62fed16e/go.mod h1:5U6c9nOLxyTdD/qrFv7N88TSxkdk5Wq8NzvB7UViDP0=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI=
github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8=
Expand Down
14 changes: 8 additions & 6 deletions internal/kratos/fake_kratos.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"

"github.com/ory/fosite"
client "github.com/ory/kratos-client-go"
)

type (
Expand All @@ -17,9 +18,10 @@ type (
)

const (
FakeSessionID = "fake-kratos-session-id"
FakeUsername = "fake-kratos-username"
FakePassword = "fake-kratos-password" // nolint: gosec
FakeSessionID = "fake-kratos-session-id"
FakeUsername = "fake-kratos-username"
FakePassword = "fake-kratos-password" // nolint: gosec
FakeIdentityID = "fake-kratos-identity-id"
)

var _ Client = new(FakeKratos)
Expand All @@ -35,11 +37,11 @@ func (f *FakeKratos) DisableSession(_ context.Context, identityProviderSessionID
return nil
}

func (f *FakeKratos) Authenticate(_ context.Context, username, password string) error {
func (f *FakeKratos) Authenticate(_ context.Context, username, password string) (*client.Session, error) {
if username == FakeUsername && password == FakePassword {
return nil
return &client.Session{Identity: &client.Identity{Id: FakeIdentityID}}, nil
}
return fosite.ErrNotFound
return nil, fosite.ErrNotFound
}

func (f *FakeKratos) Reset() {
Expand Down
14 changes: 7 additions & 7 deletions internal/kratos/kratos.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type (
}
Client interface {
DisableSession(ctx context.Context, identityProviderSessionID string) error
Authenticate(ctx context.Context, name, secret string) error
Authenticate(ctx context.Context, name, secret string) (*client.Session, error)
}
Default struct {
dependencies
Expand All @@ -42,7 +42,7 @@ func New(d dependencies) Client {
return &Default{dependencies: d}
}

func (k *Default) Authenticate(ctx context.Context, name, secret string) (err error) {
func (k *Default) Authenticate(ctx context.Context, name, secret string) (session *client.Session, err error) {
ctx, span := k.Tracer(ctx).Tracer().Start(ctx, "kratos.Authenticate")
otelx.End(span, &err)

Expand All @@ -52,28 +52,28 @@ func (k *Default) Authenticate(ctx context.Context, name, secret string) (err er
span.SetAttributes(attribute.Bool("skipped", true))
span.SetAttributes(attribute.String("reason", "kratos public url not set"))

return errors.New("kratos public url not set")
return nil, errors.New("kratos public url not set")
}

kratos := k.newKratosClient(ctx, publicURL)

flow, _, err := kratos.FrontendAPI.CreateNativeLoginFlow(ctx).Execute()
if err != nil {
return err
return nil, err
}

_, _, err = kratos.FrontendAPI.UpdateLoginFlow(ctx).Flow(flow.Id).UpdateLoginFlowBody(client.UpdateLoginFlowBody{
res, _, err := kratos.FrontendAPI.UpdateLoginFlow(ctx).Flow(flow.Id).UpdateLoginFlowBody(client.UpdateLoginFlowBody{
UpdateLoginFlowWithPasswordMethod: &client.UpdateLoginFlowWithPasswordMethod{
Method: "password",
Identifier: name,
Password: secret,
},
}).Execute()
if err != nil {
return fosite.ErrNotFound.WithWrap(err)
return nil, fosite.ErrNotFound.WithWrap(err)
}

return nil
return &res.Session, nil
}

func (k *Default) DisableSession(ctx context.Context, identityProviderSessionID string) (err error) {
Expand Down
9 changes: 7 additions & 2 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,8 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypePassword)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
Expand All @@ -975,9 +976,13 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}

// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) {
session.Subject = accessRequest.GetClient().GetID()
}
// only for password grant, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypePassword)) {
session.Subject = accessRequest.GetRequestForm().Get("username")
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(ctx).String()
Expand Down
103 changes: 99 additions & 4 deletions oauth2/oauth2_rop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@ package oauth2_test

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/ory/fosite/compose"
"github.com/ory/fosite/token/jwt"
hydra "github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/internal/testhelpers"
hydraoauth2 "github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
)

Expand All @@ -32,7 +41,12 @@ func TestResourceOwnerPasswordGrant(t *testing.T) {
secret := uuid.New().String()
client := &hydra.Client{
Secret: secret,
GrantTypes: []string{"password"},
GrantTypes: []string{"password", "refresh_token"},
Scope: "offline",
Lifespans: hydra.Lifespans{
PasswordGrantAccessTokenLifespan: x.NullDuration{Duration: 1 * time.Hour, Valid: true},
PasswordGrantRefreshTokenLifespan: x.NullDuration{Duration: 1 * time.Hour, Valid: true},
},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, client))

Expand All @@ -44,15 +58,96 @@ func TestResourceOwnerPasswordGrant(t *testing.T) {
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: []string{"offline"},
}

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value")

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
require.NotEmpty(t, hookReq.Session)
require.Equal(t, kratos.FakeIdentityID, hookReq.Session.Extra["identity_id"])
require.NotEmpty(t, hookReq.Request)
require.ElementsMatch(t, []string{}, hookReq.Request.GrantedAudience)

claims := map[string]interface{}{
"hooked": true,
"identity_id": kratos.FakeIdentityID,
}
if hookReq.Request.GrantTypes[0] == "refresh_token" {
claims["refreshed"] = true
}

hookResp := hydraoauth2.TokenHookResponse{
Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
}

w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
}))
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Auth: &config.Auth{
Type: "api_key",
Config: config.AuthConfig{
In: "header",
Name: "Authorization",
Value: "Bearer secret value",
},
},
})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt")

t.Run("case=get ROP grant token with valid username and password", func(t *testing.T) {
token, err := oauth2Config.PasswordCredentialsToken(ctx, kratos.FakeUsername, kratos.FakePassword)
require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
i := testhelpers.IntrospectToken(t, oauth2Config, token.AccessToken, adminTS)
assert.True(t, i.Get("active").Bool(), "%s", i)
assert.EqualValues(t, oauth2Config.ClientID, i.Get("client_id").String(), "%s", i)

// Access token should have hook and identity_id claims
jwtAT, err := jwt.Parse(token.AccessToken, func(token *jwt.Token) (interface{}, error) {
return reg.AccessTokenJWTStrategy().GetPublicKey(ctx)
})
require.NoError(t, err)
assert.Equal(t, kratos.FakeUsername, jwtAT.Claims["sub"])
assert.Equal(t, kratos.FakeIdentityID, jwtAT.Claims["ext"].(map[string]any)["identity_id"].(string))
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["hooked"].(bool))

t.Run("case=introspect token", func(t *testing.T) {
// Introspected token should have hook and identity_id claims
i := testhelpers.IntrospectToken(t, oauth2Config, token.AccessToken, adminTS)
assert.True(t, i.Get("active").Bool(), "%s", i)
assert.Equal(t, kratos.FakeUsername, i.Get("sub").String(), "%s", i)
assert.Equal(t, kratos.FakeIdentityID, i.Get("ext.identity_id").String(), "%s", i)
assert.True(t, i.Get("ext.hooked").Bool(), "%s", i)
assert.EqualValues(t, oauth2Config.ClientID, i.Get("client_id").String(), "%s", i)
})

t.Run("case=refresh token", func(t *testing.T) {
// Refreshed access token should have hook and identity_id claims
require.NotEmpty(t, token.RefreshToken)
token.Expiry = token.Expiry.Add(-time.Hour * 24)
refreshedToken, err := oauth2Config.TokenSource(context.Background(), token).Token()
require.NoError(t, err)

require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken)
require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken)

jwtAT, err := jwt.Parse(refreshedToken.AccessToken, func(token *jwt.Token) (interface{}, error) {
return reg.AccessTokenJWTStrategy().GetPublicKey(ctx)
})
require.NoError(t, err)
assert.Equal(t, kratos.FakeUsername, jwtAT.Claims["sub"])
assert.Equal(t, kratos.FakeIdentityID, jwtAT.Claims["ext"].(map[string]any)["identity_id"].(string))
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["hooked"].(bool))
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["refreshed"].(bool))
})
})

t.Run("case=access denied for invalid password", func(t *testing.T) {
Expand Down
14 changes: 14 additions & 0 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,17 @@ func (s *Session) UnmarshalJSON(original []byte) (err error) {

return nil
}

// GetExtraClaims implements ExtraClaimsSession for Session.
// The returned value can be modified in-place.
func (s *Session) GetExtraClaims() map[string]interface{} {
if s == nil {
return nil
}

if s.Extra == nil {
s.Extra = make(map[string]interface{})
}

return s.Extra
}
12 changes: 9 additions & 3 deletions persistence/sql/persister_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

package sql

import "context"
import (
"context"
)

func (p *Persister) Authenticate(ctx context.Context, name, secret string) error {
return p.r.Kratos().Authenticate(ctx, name, secret)
func (p *Persister) Authenticate(ctx context.Context, name, secret string) (string, error) {
session, err := p.r.Kratos().Authenticate(ctx, name, secret)
if err != nil {
return "", err
}
return session.Identity.Id, nil
}

0 comments on commit efa5d5a

Please sign in to comment.