diff --git a/compose/compose.go b/compose/compose.go index 683564537..372c8d3c9 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -32,24 +32,24 @@ type Factory func(config fosite.Configurator, storage interface{}, strategy inte // Compose takes a config, a storage, a strategy and handlers to instantiate an OAuth2Provider: // -// import "github.com/ory/fosite/compose" +// import "github.com/ory/fosite/compose" // -// // var storage = new(MyFositeStorage) -// var config = Config { -// AccessTokenLifespan: time.Minute * 30, -// // check Config for further configuration options -// } +// // var storage = new(MyFositeStorage) +// var config = Config { +// AccessTokenLifespan: time.Minute * 30, +// // check Config for further configuration options +// } // -// var strategy = NewOAuth2HMACStrategy(config) +// var strategy = NewOAuth2HMACStrategy(config) // -// var oauth2Provider = Compose( -// config, -// storage, -// strategy, -// NewOAuth2AuthorizeExplicitHandler, -// OAuth2ClientCredentialsGrantFactory, -// // for a complete list refer to the docs of this package -// ) +// var oauth2Provider = Compose( +// config, +// storage, +// strategy, +// NewOAuth2AuthorizeExplicitHandler, +// OAuth2ClientCredentialsGrantFactory, +// // for a complete list refer to the docs of this package +// ) // // Compose makes use of interface{} types in order to be able to handle a all types of stores, strategies and handlers. func Compose(config *fosite.Config, storage interface{}, strategy interface{}, factories ...Factory) fosite.OAuth2Provider { @@ -91,8 +91,10 @@ func ComposeAllEnabled(config *fosite.Config, storage interface{}, key interface }, OAuth2AuthorizeExplicitFactory, OAuth2AuthorizeImplicitFactory, + OAuth2AuthorizeDeviceFactory, OAuth2ClientCredentialsGrantFactory, OAuth2RefreshTokenGrantFactory, + OAuth2DeviceAuthorizeFactory, OAuth2ResourceOwnerPasswordCredentialsFactory, RFC7523AssertionGrantFactory, @@ -100,11 +102,13 @@ func ComposeAllEnabled(config *fosite.Config, storage interface{}, key interface OpenIDConnectImplicitFactory, OpenIDConnectHybridFactory, OpenIDConnectRefreshFactory, + OpenIDConnectDeviceFactory, OAuth2TokenIntrospectionFactory, OAuth2TokenRevocationFactory, OAuth2PKCEFactory, PushedAuthorizeHandlerFactory, + OAuth2DevicePKCEFactory, ) } diff --git a/compose/compose_oauth2.go b/compose/compose_oauth2.go index f80316f76..174377c70 100644 --- a/compose/compose_oauth2.go +++ b/compose/compose_oauth2.go @@ -125,3 +125,25 @@ func OAuth2StatelessJWTIntrospectionFactory(config fosite.Configurator, storage Config: config, } } + +func OAuth2AuthorizeDeviceFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &oauth2.AuthorizeDeviceGrantTypeHandler{ + DeviceCodeStrategy: strategy.(oauth2.DeviceCodeStrategy), + UserCodeStrategy: strategy.(oauth2.UserCodeStrategy), + AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), + RefreshTokenStrategy: strategy.(oauth2.RefreshTokenStrategy), + AuthorizeCodeStrategy: strategy.(oauth2.AuthorizeCodeStrategy), + CoreStorage: storage.(oauth2.CoreStorage), + Config: config, + } +} + +func OAuth2DeviceAuthorizeFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &oauth2.DeviceAuthorizationHandler{ + DeviceCodeStorage: storage.(oauth2.DeviceCodeStorage), + UserCodeStorage: storage.(oauth2.UserCodeStorage), + DeviceCodeStrategy: strategy.(oauth2.DeviceCodeStrategy), + UserCodeStrategy: strategy.(oauth2.UserCodeStrategy), + Config: config, + } +} diff --git a/compose/compose_openid.go b/compose/compose_openid.go index 4ced36a55..ca2bfc1bb 100644 --- a/compose/compose_openid.go +++ b/compose/compose_openid.go @@ -62,7 +62,8 @@ func OpenIDConnectImplicitFactory(config fosite.Configurator, storage interface{ AuthorizeImplicitGrantTypeHandler: &oauth2.AuthorizeImplicitGrantTypeHandler{ AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), AccessTokenStorage: storage.(oauth2.AccessTokenStorage), - Config: config, + + Config: config, }, Config: config, IDTokenHandleHelper: &openid.IDTokenHandleHelper{ @@ -97,3 +98,20 @@ func OpenIDConnectHybridFactory(config fosite.Configurator, storage interface{}, OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), } } + +// OpenIDConnectDeviceFactory creates an OpenID Connect device ("device code flow") grant handler. +// +// **Important note:** You must add this handler *after* you have added an OAuth2 authorize code handler! +func OpenIDConnectDeviceFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &openid.OpenIDConnectDeviceHandler{ + CoreStorage: storage.(oauth2.CoreStorage), + DeviceCodeStrategy: strategy.(oauth2.DeviceCodeStrategy), + UserCodeStrategy: strategy.(oauth2.UserCodeStrategy), + OpenIDConnectRequestStorage: storage.(openid.OpenIDConnectRequestStorage), + IDTokenHandleHelper: &openid.IDTokenHandleHelper{ + IDTokenStrategy: strategy.(openid.OpenIDConnectTokenStrategy), + }, + OpenIDConnectRequestValidator: openid.NewOpenIDConnectRequestValidator(strategy.(jwt.Signer), config), + Config: config, + } +} diff --git a/compose/compose_pkce.go b/compose/compose_pkce.go index b5d2ddc33..4dfc003d0 100644 --- a/compose/compose_pkce.go +++ b/compose/compose_pkce.go @@ -35,3 +35,14 @@ func OAuth2PKCEFactory(config fosite.Configurator, storage interface{}, strategy Config: config, } } + +func OAuth2DevicePKCEFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &pkce.HandlerDevice{ + CoreStorage: storage.(oauth2.CoreStorage), + DeviceCodeStrategy: strategy.(oauth2.DeviceCodeStrategy), + UserCodeStrategy: strategy.(oauth2.UserCodeStrategy), + AuthorizeCodeStrategy: strategy.(oauth2.AuthorizeCodeStrategy), + Storage: storage.(pkce.PKCERequestStorage), + Config: config, + } +} diff --git a/compose/compose_strategy.go b/compose/compose_strategy.go index dda18f41a..b31c42816 100644 --- a/compose/compose_strategy.go +++ b/compose/compose_strategy.go @@ -45,6 +45,7 @@ type HMACSHAStrategyConfigurator interface { fosite.GlobalSecretProvider fosite.RotatedGlobalSecretsProvider fosite.HMACHashingProvider + fosite.DeviceAndUserCodeLifespanProvider } func NewOAuth2HMACStrategy(config HMACSHAStrategyConfigurator) *oauth2.HMACSHAStrategy { diff --git a/config.go b/config.go index 1edc5bc5f..5d4d31e4c 100644 --- a/config.go +++ b/config.go @@ -19,6 +19,15 @@ type AuthorizeCodeLifespanProvider interface { GetAuthorizeCodeLifespan(ctx context.Context) time.Duration } +type DeviceAndUserCodeLifespanProvider interface { + GetDeviceAndUserCodeLifespan(ctx context.Context) time.Duration +} + +type DeviceUriProvider interface { + GetDeviceVerificationURL(ctx context.Context) string + GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration +} + // RefreshTokenLifespanProvider returns the provider for configuring the refresh token lifespan. type RefreshTokenLifespanProvider interface { // GetRefreshTokenLifespan returns the refresh token lifespan. @@ -272,6 +281,10 @@ type PushedAuthorizeRequestHandlersProvider interface { GetPushedAuthorizeEndpointHandlers(ctx context.Context) PushedAuthorizeEndpointHandlers } +type DeviceAuthorizeEndpointHandlersProvider interface { + GetDeviceAuthorizeEndpointHandlers(ctx context.Context) DeviceAuthorizeEndpointHandlers +} + // UseLegacyErrorFormatProvider returns the provider for configuring whether to use the legacy error format. // // DEPRECATED: Do not use this flag anymore. diff --git a/config_default.go b/config_default.go index 9ff7dc941..1da029d7a 100644 --- a/config_default.go +++ b/config_default.go @@ -80,6 +80,7 @@ var ( _ RevocationHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ DeviceAuthorizeEndpointHandlersProvider = (*Config)(nil) ) type Config struct { @@ -93,6 +94,15 @@ type Config struct { // AuthorizeCodeLifespan sets how long an authorize code is going to be valid. Defaults to fifteen minutes. AuthorizeCodeLifespan time.Duration + // Sets how long a device user/device code pair is valid for + DeviceAndUserCodeLifespan time.Duration + + // DeviceAuthTokenPollingInterval sets the interval that clients should check for device code grants + DeviceAuthTokenPollingInterval time.Duration + + // DeviceVerificationURL is the URL of the device verification endpoint, this is is included with the device code request responses + DeviceVerificationURL string + // IDTokenLifespan sets the default id token lifetime. Defaults to one hour. IDTokenLifespan time.Duration @@ -212,6 +222,8 @@ type Config struct { // PushedAuthorizeEndpointHandlers is a list of handlers that are called before the PAR endpoint is served. PushedAuthorizeEndpointHandlers PushedAuthorizeEndpointHandlers + DeviceAuthorizeEndpointHandlers DeviceAuthorizeEndpointHandlers + // GlobalSecret is the global secret used to sign and verify signatures. GlobalSecret []byte @@ -260,6 +272,10 @@ func (c *Config) GetTokenIntrospectionHandlers(ctx context.Context) TokenIntrosp return c.TokenIntrospectionHandlers } +func (c *Config) GetDeviceAuthorizeEndpointHandlers(ctx context.Context) DeviceAuthorizeEndpointHandlers { + return c.DeviceAuthorizeEndpointHandlers +} + func (c *Config) GetRevocationHandlers(ctx context.Context) RevocationHandlers { return c.RevocationHandlers } @@ -378,6 +394,13 @@ func (c *Config) GetAuthorizeCodeLifespan(_ context.Context) time.Duration { return c.AuthorizeCodeLifespan } +func (c *Config) GetDeviceAndUserCodeLifespan(_ context.Context) time.Duration { + if c.AuthorizeCodeLifespan == 0 { + return time.Minute * 10 + } + return c.DeviceAndUserCodeLifespan +} + // GeIDTokenLifespan returns how long an id token should be valid. Defaults to one hour. func (c *Config) GetIDTokenLifespan(_ context.Context) time.Duration { if c.IDTokenLifespan == 0 { @@ -506,3 +529,14 @@ func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Dur func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool { return c.IsPushedAuthorizeEnforced } + +func (c *Config) GetDeviceVerificationURL(ctx context.Context) string { + return c.DeviceVerificationURL +} + +func (c *Config) GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration { + if c.DeviceAuthTokenPollingInterval == 0 { + return time.Second * 10 + } + return c.DeviceAuthTokenPollingInterval +} diff --git a/errors.go b/errors.go index abf001272..4da5ed5b3 100644 --- a/errors.go +++ b/errors.go @@ -42,6 +42,10 @@ var ( // ErrInvalidatedAuthorizeCode is an error indicating that an authorization code has been // used previously. ErrInvalidatedAuthorizeCode = errors.New("Authorization code has ben invalidated") + // ErrInvalidatedDeviceCode is an error indicating that a device code has been used previously. + ErrInvalidatedDeviceCode = errors.New("Device code has been invalidated") + // ErrInvalidatedUserCode is an error indicating that a user code has been used previously. + ErrInvalidatedUserCode = errors.New("user code has been invalidated") // ErrSerializationFailure is an error indicating that the transactional capable storage could not guarantee // consistency of Update & Delete operations on the same rows between multiple sessions. ErrSerializationFailure = errors.New("The request could not be completed due to concurrent access") @@ -221,6 +225,11 @@ var ( ErrorField: errJTIKnownName, CodeField: http.StatusBadRequest, } + ErrAuthorizationPending = &RFC6749Error{ + DescriptionField: "The authorization request is still pending as the end user hasn't yet completed the user-interaction steps.", + ErrorField: errAuthorizationPending, + CodeField: http.StatusForbidden, + } ) const ( @@ -258,6 +267,7 @@ const ( errRequestURINotSupportedName = "request_uri_not_supported" errRegistrationNotSupportedName = "registration_not_supported" errJTIKnownName = "jti_known" + errAuthorizationPending = "authorization_pending" ) type ( diff --git a/fosite.go b/fosite.go index 285a6be3f..db1f512aa 100644 --- a/fosite.go +++ b/fosite.go @@ -100,6 +100,20 @@ func (a *PushedAuthorizeEndpointHandlers) Append(h PushedAuthorizeEndpointHandle *a = append(*a, h) } +// DeviceAuthorizeEndpointHandler is a list of DeviceAuthorizeEndpointHandler +type DeviceAuthorizeEndpointHandlers []DeviceAuthorizeEndpointHandler + +// Append adds an AuthorizeEndpointHandler to this list. Ignores duplicates based on reflect.TypeOf. +func (a *DeviceAuthorizeEndpointHandlers) Append(h DeviceAuthorizeEndpointHandler) { + for _, this := range *a { + if reflect.TypeOf(this) == reflect.TypeOf(h) { + return + } + } + + *a = append(*a, h) +} + var _ OAuth2Provider = (*Fosite)(nil) type Configurator interface { @@ -125,6 +139,7 @@ type Configurator interface { AccessTokenLifespanProvider RefreshTokenLifespanProvider AuthorizeCodeLifespanProvider + DeviceAndUserCodeLifespanProvider TokenEntropyProvider RotatedGlobalSecretsProvider GlobalSecretProvider @@ -149,6 +164,8 @@ type Configurator interface { TokenIntrospectionHandlersProvider RevocationHandlersProvider UseLegacyErrorFormatProvider + DeviceAuthorizeEndpointHandlersProvider + DeviceUriProvider } func NewOAuth2Provider(s Storage, c Configurator) *Fosite { diff --git a/handler.go b/handler.go index f1319a610..0432e6036 100644 --- a/handler.go +++ b/handler.go @@ -84,3 +84,13 @@ type PushedAuthorizeEndpointHandler interface { // the pushed authorize request, he must return nil and NOT modify session nor responder neither requester. HandlePushedAuthorizeEndpointRequest(ctx context.Context, requester AuthorizeRequester, responder PushedAuthorizeResponder) error } + +type DeviceAuthorizeEndpointHandler interface { + // HandleDeviceAuthorizeRequest handles a device authorize endpoint request. To extend the handler's capabilities, the http request + // is passed along, if further information retrieval is required. If the handler feels that he is not responsible for + // the device authorize request, he must return nil and NOT modify session nor responder neither requester. + // + // The following spec is a good example of what HandleDeviceAuthorizeRequest should do. + // * https://tools.ietf.org/html/rfc8628#section-3.2 + HandleDeviceAuthorizeEndpointRequest(ctx context.Context, requester Requester, responder DeviceAuthorizeResponder) error +} diff --git a/handler/oauth2/device_authorization.go b/handler/oauth2/device_authorization.go new file mode 100644 index 000000000..39d33c8d8 --- /dev/null +++ b/handler/oauth2/device_authorization.go @@ -0,0 +1,60 @@ +package oauth2 + +import ( + "context" + "fmt" + "time" + + "github.com/ory/fosite" + "github.com/ory/x/errorsx" +) + +// DeviceAuthorizationHandler is a response handler for the Device Authorisation Grant as +// defined in https://tools.ietf.org/html/rfc8628#section-3.1 +type DeviceAuthorizationHandler struct { + DeviceCodeStorage DeviceCodeStorage + UserCodeStorage UserCodeStorage + DeviceCodeStrategy DeviceCodeStrategy + UserCodeStrategy UserCodeStrategy + Config fosite.Configurator +} + +func (d *DeviceAuthorizationHandler) HandleDeviceAuthorizeEndpointRequest(ctx context.Context, dar fosite.Requester, resp fosite.DeviceAuthorizeResponder) error { + fmt.Println("DeviceAuthorizationHandler :: HandleDeviceAuthorizeEndpointRequest ++") + deviceCode, err := d.DeviceCodeStrategy.GenerateDeviceCode() + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + userCode, err := d.UserCodeStrategy.GenerateUserCode() + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + fmt.Println("DeviceAuthorizationHandler :: HandleDeviceAuthorizeEndpointRequest +++") + + userCodeSignature := d.UserCodeStrategy.UserCodeSignature(ctx, userCode) + deviceCodeSignature := d.DeviceCodeStrategy.DeviceCodeSignature(ctx, deviceCode) + + // Set User Code expiry time + dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second)) + dar.SetID(deviceCodeSignature) + + fmt.Println("DeviceAuthorizationHandler :: HandleDeviceAuthorizeEndpointRequest ++++") + + // Store the User Code session (this has no real data other that the uer and device code), can be converted into a 'full' session after user auth + if err := d.UserCodeStorage.CreateUserCodeSession(ctx, userCodeSignature, dar); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + fmt.Println("DeviceAuthorizationHandler :: HandleDeviceAuthorizeEndpointRequest +++++") + + // Populate the response fields + resp.SetDeviceCode(deviceCode) + resp.SetUserCode(userCode) + resp.SetVerificationURI(d.Config.GetDeviceVerificationURL(ctx)) + resp.SetVerificationURIComplete(d.Config.GetDeviceVerificationURL(ctx) + "?user_code=" + userCode) + resp.SetExpiresIn(int64(time.Until(dar.GetSession().GetExpiresAt(fosite.UserCode)).Seconds())) + resp.SetInterval(int(d.Config.GetDeviceAuthTokenPollingInterval(ctx).Seconds())) + return nil +} diff --git a/handler/oauth2/device_authorization_test.go b/handler/oauth2/device_authorization_test.go new file mode 100644 index 000000000..fe6395ee1 --- /dev/null +++ b/handler/oauth2/device_authorization_test.go @@ -0,0 +1,52 @@ +package oauth2 + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/ory/fosite" + "github.com/ory/fosite/storage" + "github.com/stretchr/testify/assert" +) + +func Test_HandleDeviceAuthorizeEndpointRequest(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + deviceStore := storage.NewMemoryStore() + userStore := storage.NewMemoryStore() + handler := DeviceAuthorizationHandler{ + DeviceCodeStorage: deviceStore, + UserCodeStorage: userStore, + DeviceCodeStrategy: hmacshaStrategy, + UserCodeStrategy: hmacshaStrategy, + Config: &fosite.Config{ + DeviceAndUserCodeLifespan: time.Minute * 10, + DeviceAuthTokenPollingInterval: time.Second * 10, + DeviceVerificationURL: "www.test.com", + AccessTokenLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenScopes: []string{"offline"}, + }, + } + + req := &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"code"}, + Request: fosite.Request{ + Session: &fosite.DefaultSession{}, + }, + } + resp := fosite.NewDeviceAuthorizeResponse() + + handler.HandleDeviceAuthorizeEndpointRequest(nil, req, resp) + + assert.NotEmpty(t, resp.GetDeviceCode()) + assert.NotEmpty(t, resp.GetUserCode()) + assert.Equal(t, len(resp.GetUserCode()), 8) + assert.Equal(t, len(resp.GetDeviceCode()), 100) + assert.Equal(t, resp.GetVerificationURI(), "www.test.com") + +} diff --git a/handler/oauth2/flow_authorize_code_auth_test.go b/handler/oauth2/flow_authorize_code_auth_test.go index be0f1cc8d..c95fd5321 100644 --- a/handler/oauth2/flow_authorize_code_auth_test.go +++ b/handler/oauth2/flow_authorize_code_auth_test.go @@ -40,7 +40,7 @@ func parseUrl(uu string) *url.URL { } func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) { - for k, strategy := range map[string]CoreStrategy{ + for k, strategy := range map[string]AuthorizeCodeStrategy{ "hmac": &hmacshaStrategy, } { t.Run("strategy="+k, func(t *testing.T) { diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index 326d617c9..3bc754669 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -258,7 +258,7 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) { } func TestAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { - for k, strategy := range map[string]CoreStrategy{ + for k, strategy := range map[string]AuthorizeCodeStrategy{ "hmac": &hmacshaStrategy, } { t.Run("strategy="+k, func(t *testing.T) { diff --git a/handler/oauth2/flow_device_code_auth.go b/handler/oauth2/flow_device_code_auth.go new file mode 100644 index 000000000..6ebac1b19 --- /dev/null +++ b/handler/oauth2/flow_device_code_auth.go @@ -0,0 +1,61 @@ +package oauth2 + +import ( + "context" + "fmt" + + "github.com/ory/fosite" + "github.com/ory/x/errorsx" +) + +type AuthorizeDeviceGrantTypeHandler struct { + CoreStorage CoreStorage + DeviceCodeStrategy DeviceCodeStrategy + UserCodeStrategy UserCodeStrategy + AccessTokenStrategy AccessTokenStrategy + RefreshTokenStrategy RefreshTokenStrategy + AuthorizeCodeStrategy AuthorizeCodeStrategy + Config fosite.Configurator +} + +func (c *AuthorizeDeviceGrantTypeHandler) HandleAuthorizeEndpointRequest(ctx context.Context, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) error { + + if !ar.GetResponseTypes().ExactOne("device_code") { + return nil + } + + if !ar.GetClient().GetGrantTypes().Has("urn:ietf:params:oauth:grant-type:device_code") { + return nil + } + + resp.AddParameter("state", ar.GetState()) + + userCode := ar.GetRequestForm().Get("user_code") + userCodeSignature := c.UserCodeStrategy.UserCodeSignature(ctx, userCode) + + session, err := c.CoreStorage.GetUserCodeSession(ctx, userCodeSignature, fosite.NewRequest().Session) + if err != nil { + return err + } + + fmt.Println("SUBJECT : " + ar.GetSession().GetSubject()) + + if session.GetClient().GetID() != ar.GetClient().GetID() { + return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("The OAuth 2.0 Client ID from this request does not match the one from the authorize request.")) + } + + /* + expires := session.GetSession().GetExpiresAt(fosite.UserCode) + if time.Now().UTC().After(expires) { + return errorsx.WithStack(fosite.ErrTokenExpired) + }*/ + + // session.GetID() is the HMAC signature of the device code generated in the inital request + err = c.CoreStorage.CreateDeviceCodeSession(ctx, session.GetID(), ar) + if err != nil { + return errorsx.WithStack(err) + } + + ar.SetResponseTypeHandled("device_code") + return nil +} diff --git a/handler/oauth2/flow_device_code_auth_test.go b/handler/oauth2/flow_device_code_auth_test.go new file mode 100644 index 000000000..842b4c035 --- /dev/null +++ b/handler/oauth2/flow_device_code_auth_test.go @@ -0,0 +1,200 @@ +package oauth2 + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/storage" + "github.com/stretchr/testify/require" +) + +func TestAuthorizeCode_HandleDeviceAuthorizeEndpointRequest(t *testing.T) { + + for k, strategy := range map[string]CoreStrategy{ + "hmac": &hmacshaStrategy, + } { + t.Run("strategy="+k, func(t *testing.T) { + store := storage.NewMemoryStore() + handler := AuthorizeDeviceGrantTypeHandler{ + CoreStorage: store, + DeviceCodeStrategy: hmacshaStrategy, + UserCodeStrategy: hmacshaStrategy, + AccessTokenStrategy: strategy, + RefreshTokenStrategy: strategy, + AuthorizeCodeStrategy: strategy, + Config: &fosite.Config{ + DeviceAndUserCodeLifespan: time.Minute * 10, + DeviceAuthTokenPollingInterval: time.Second * 10, + DeviceVerificationURL: "localhost", + AccessTokenLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenScopes: []string{"offline"}, + }, + } + for _, c := range []struct { + handler AuthorizeDeviceGrantTypeHandler + areq *fosite.AuthorizeRequest + breq *fosite.AuthorizeRequest + expire time.Duration + description string + expectErr error + expect func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.AuthorizeResponse) + }{ + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{""}, + Request: *fosite.NewRequest(), + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{""}, + Request: *fosite.NewRequest(), + }, + description: "should pass because not responsible for handling an empty response type", + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"foo"}, + Request: *fosite.NewRequest(), + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{""}, + Request: *fosite.NewRequest(), + }, + description: "should pass because not responsible for handling an invalid response type", + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"code"}, + }, + }, + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"code"}, + }, + }, + }, + description: "should pass because not responsible for handling an invalid grant type", + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + description: "should pass as session and request have matching client id", + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Broken", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + description: "should fail due to a missmatch in session and request ClientID", + expire: time.Minute * 10, + expectErr: fosite.ErrInvalidGrant, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + breq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: "Default", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + }, + Form: url.Values{"user_code": {"ABC123"}}, + }, + }, + description: "should fail due to expired user session", + expire: -(time.Minute * 10), + //expectErr: fosite.ErrTokenExpired, + }, + } { + t.Run("case="+c.description, func(t *testing.T) { + + c.areq.SetID("ID1") + c.areq.Session = &fosite.DefaultSession{Subject: "A"} + c.breq.Session = &fosite.DefaultSession{Subject: "A"} + expireAt := time.Now().UTC().Add(c.expire) + c.areq.Session.SetExpiresAt(fosite.UserCode, expireAt) + userCodeSig := hmacshaStrategy.UserCodeSignature(context.Background(), c.areq.Form.Get("user_code")) + store.CreateUserCodeSession(nil, userCodeSig, c.areq) + + aresp := fosite.NewAuthorizeResponse() + err := c.handler.HandleAuthorizeEndpointRequest(nil, c.breq, aresp) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + + if c.expect != nil { + c.expect(t, c.areq, aresp) + } + }) + } + }) + } +} diff --git a/handler/oauth2/flow_device_code_token.go b/handler/oauth2/flow_device_code_token.go new file mode 100644 index 000000000..b349b8de8 --- /dev/null +++ b/handler/oauth2/flow_device_code_token.go @@ -0,0 +1,163 @@ +package oauth2 + +import ( + "context" + "fmt" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/storage" + "github.com/ory/x/errorsx" +) + +const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" + +func (d *AuthorizeDeviceGrantTypeHandler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error { + + if !d.CanHandleTokenEndpointRequest(ctx, requester) { + return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest)) + } + + if !requester.GetClient().GetGrantTypes().Has(deviceCodeGrantType) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use authorization grant \"" + deviceCodeGrantType + "\".")) + } + + code := requester.GetRequestForm().Get("device_code") + if code == "" { + return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest.WithHint("device_code missing form body"))) + } + codeSignature := d.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) + + // Get the device code session to validate based on HMAC of the device code supplied + session, err := d.CoreStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession()) + + if err != nil { + return errorsx.WithStack(fosite.ErrAuthorizationPending) + } + + requester.SetRequestedScopes(session.GetRequestedScopes()) + requester.SetRequestedAudience(session.GetRequestedAudience()) + + if requester.GetClient().GetID() != session.GetClient().GetID() { + return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("The OAuth 2.0 Client ID from this request does not match the one from the authorize request.")) + } + + //expires := session.GetSession().GetExpiresAt(fosite.UserCode) + //if time.Now().UTC().After(expires) { + // return errorsx.WithStack(fosite.ErrTokenExpired) + //} + + //requester.SetSession(session.GetSession()) + //requester.SetID(session.GetID()) + + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeAuthorizationCode, fosite.AccessToken, d.Config.GetAccessTokenLifespan(ctx)) + requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) + + rtLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeAuthorizationCode, fosite.RefreshToken, d.Config.GetRefreshTokenLifespan(ctx)) + if rtLifespan > -1 { + requester.GetSession().SetExpiresAt(fosite.RefreshToken, time.Now().UTC().Add(rtLifespan).Round(time.Second)) + } + + return nil +} + +func (d *AuthorizeDeviceGrantTypeHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { + return true +} + +func (d *AuthorizeDeviceGrantTypeHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { + fmt.Println("CanHandleTokenEndpointRequest OAUTH") + return requester.GetGrantTypes().ExactOne(deviceCodeGrantType) +} + +func (d *AuthorizeDeviceGrantTypeHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + + if !d.CanHandleTokenEndpointRequest(ctx, requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + code := requester.GetRequestForm().Get("device_code") + if code == "" { + return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest.WithHint("device_code missing form body"))) + } + codeSignature := d.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) + + if err := d.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code); err != nil { + // This needs to happen after store retrieval for the session to be hydrated properly + return err + } + + // Get the device code session ready for exchange to auth / refresh / oidc sessions + session, err := d.CoreStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession()) + + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error())) + } + + for _, scope := range session.GetGrantedScopes() { + requester.GrantScope(scope) + } + + for _, audience := range session.GetGrantedAudience() { + requester.GrantAudience(audience) + } + + access, accessSignature, err := d.AccessTokenStrategy.GenerateAccessToken(ctx, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + var refresh, refreshSignature string + + if d.canIssueRefreshToken(ctx, requester) { + refresh, refreshSignature, err = d.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + } + + ctx, err = storage.MaybeBeginTx(ctx, d.CoreStorage) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + defer func() { + if err != nil { + if rollBackTxnErr := storage.MaybeRollbackTx(ctx, d.CoreStorage); rollBackTxnErr != nil { + err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) + } + } + }() + + if err = d.CoreStorage.DeleteDeviceCodeSession(ctx, code); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } else if err = d.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } else if refreshSignature != "" { + if err = d.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + } + + responder.SetAccessToken(access) + responder.SetTokenType("bearer") + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeAuthorizationCode, fosite.AccessToken, d.Config.GetAccessTokenLifespan(ctx)) + responder.SetExpiresIn(getExpiresIn(requester, fosite.AccessToken, atLifespan, time.Now().UTC())) + responder.SetScopes(requester.GetGrantedScopes()) + if refresh != "" { + responder.SetExtra("refresh_token", refresh) + } + + if err = storage.MaybeCommitTx(ctx, d.CoreStorage); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + return nil +} + +func (c *AuthorizeDeviceGrantTypeHandler) canIssueRefreshToken(ctx context.Context, request fosite.Requester) bool { + // Require one of the refresh token scopes, if set. + if len(c.Config.GetRefreshTokenScopes(ctx)) > 0 && !request.GetGrantedScopes().HasOneOf(c.Config.GetRefreshTokenScopes(ctx)...) { + return false + } + return true +} diff --git a/handler/oauth2/flow_device_code_token_test.go b/handler/oauth2/flow_device_code_token_test.go new file mode 100644 index 000000000..cb23f6243 --- /dev/null +++ b/handler/oauth2/flow_device_code_token_test.go @@ -0,0 +1,350 @@ +package oauth2 + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/storage" + "github.com/stretchr/testify/require" +) + +func TestAuthorizeCode_HandleDeviceTokenEndpointRequest(t *testing.T) { + + for k, strategy := range map[string]CoreStrategy{ + "hmac": &hmacshaStrategy, + } { + t.Run("strategy="+k, func(t *testing.T) { + store := storage.NewMemoryStore() + handler := AuthorizeDeviceGrantTypeHandler{ + CoreStorage: store, + DeviceCodeStrategy: hmacshaStrategy, + UserCodeStrategy: hmacshaStrategy, + AccessTokenStrategy: strategy, + RefreshTokenStrategy: strategy, + AuthorizeCodeStrategy: strategy, + Config: &fosite.Config{ + DeviceAndUserCodeLifespan: time.Minute * 10, + DeviceAuthTokenPollingInterval: time.Second * 10, + DeviceVerificationURL: "localhost", + AccessTokenLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenScopes: []string{"offline"}, + }, + } + for _, c := range []struct { + handler AuthorizeDeviceGrantTypeHandler + areq *fosite.AccessRequest + breq *fosite.AccessRequest + createDeviceSession bool + expire time.Duration + description string + expectErr error + expect func(t *testing.T, areq *fosite.AccessRequest) + }{ + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"authorization_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"authorization_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + }, + }, + description: "Should fail due to wrong grant type", + expectErr: fosite.ErrUnknownRequest, + createDeviceSession: false, + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + }, + }, + description: "Should fail due to no device_code supplied", + expectErr: fosite.ErrUnauthorizedClient, + createDeviceSession: false, + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should fail due to no user_code session available", + expectErr: fosite.ErrUnauthorizedClient, + createDeviceSession: false, + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should pass as device_code form data and session are available", + createDeviceSession: true, + expire: time.Minute * 10, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should fail as session expired", + createDeviceSession: true, + expire: -(time.Minute * 10), + expectErr: fosite.ErrUnauthorizedClient, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + breq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "bar", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{Subject: "A"}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should fail as session and request clients do not match", + createDeviceSession: true, + expire: time.Minute * 10, + expectErr: fosite.ErrUnauthorizedClient, + }, + } { + t.Run("case="+c.description, func(t *testing.T) { + + if c.createDeviceSession { + c.areq.SetID("ID1") + c.areq.Session = &fosite.DefaultSession{} + expireAt := time.Now().UTC().Add(c.expire) + c.areq.Session.SetExpiresAt(fosite.UserCode, expireAt) + deviceSignature := hmacshaStrategy.DeviceCodeSignature(context.Background(), c.areq.Form.Get("device_code")) + store.CreateDeviceCodeSession(nil, deviceSignature, c.areq) + } + + err := c.handler.HandleTokenEndpointRequest(nil, c.breq) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + + if c.expect != nil { + c.expect(t, c.areq) + } + }) + } + }) + } +} + +func TestAuthorizeCode_PopulateDeviceTokenEndpointResponse(t *testing.T) { + + for k, strategy := range map[string]CoreStrategy{ + "hmac": &hmacshaStrategy, + } { + t.Run("strategy="+k, func(t *testing.T) { + store := storage.NewMemoryStore() + handler := AuthorizeDeviceGrantTypeHandler{ + CoreStorage: store, + DeviceCodeStrategy: hmacshaStrategy, + UserCodeStrategy: hmacshaStrategy, + AccessTokenStrategy: strategy, + RefreshTokenStrategy: strategy, + AuthorizeCodeStrategy: strategy, + Config: &fosite.Config{ + DeviceAndUserCodeLifespan: time.Minute * 10, + DeviceAuthTokenPollingInterval: time.Second * 10, + DeviceVerificationURL: "localhost", + AccessTokenLifespan: time.Hour, + RefreshTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenScopes: []string{"offline"}, + }, + } + for _, c := range []struct { + handler AuthorizeDeviceGrantTypeHandler + areq *fosite.AccessRequest + createDeviceSession bool + description string + expectErr error + expect func(t *testing.T, areq *fosite.AccessRequest) + }{ + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"authorization_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), + }, + }, + description: "Should fail due to wrong grant type", + expectErr: fosite.ErrUnknownRequest, + createDeviceSession: false, + }, { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), + }, + }, + description: "Should fail due to no device_code supplied", + expectErr: fosite.ErrUnknownRequest, + createDeviceSession: false, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{""}}, + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should fail due to no user_code session available", + expectErr: fosite.ErrInvalidRequest, + createDeviceSession: false, + }, + { + handler: handler, + areq: &fosite.AccessRequest{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), + GrantedScope: fosite.Arguments{"openid", "offline"}, + GrantedAudience: fosite.Arguments{"www.websitesite.com"}, + Form: url.Values{"device_code": {"ABC1234"}}, + }, + }, + description: "Should pass as device_code form data and session are available", + createDeviceSession: true, + }, + } { + t.Run("case="+c.description, func(t *testing.T) { + + c.areq.GetSession().SetExpiresAt(fosite.UserCode, time.Now().Add(time.Minute*5)) + if c.createDeviceSession { + c.areq.SetID("ID1") + deviceSig := hmacshaStrategy.DeviceCodeSignature(context.TODO(), c.areq.Form.Get("device_code")) + store.CreateDeviceCodeSession(nil, deviceSig, c.areq) + } + + resp := fosite.NewAccessResponse() + err := c.handler.PopulateTokenEndpointResponse(nil, c.areq, resp) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + + accessToken := resp.GetAccessToken() + refreshToken := resp.GetExtra("refresh_token") + + // Make sure we only create tokens if we have a device session available + if c.createDeviceSession { + require.NotEmpty(t, accessToken) + require.NotEmpty(t, refreshToken) + } else { + require.Empty(t, accessToken) + require.Empty(t, refreshToken) + } + + if c.expect != nil { + c.expect(t, c.areq) + } + }) + } + }) + } +} diff --git a/handler/oauth2/storage.go b/handler/oauth2/storage.go index 550211993..a72002564 100644 --- a/handler/oauth2/storage.go +++ b/handler/oauth2/storage.go @@ -31,6 +31,8 @@ type CoreStorage interface { AuthorizeCodeStorage AccessTokenStorage RefreshTokenStorage + DeviceCodeStorage + UserCodeStorage } // AuthorizeCodeStorage handles storage requests related to authorization codes. @@ -66,3 +68,14 @@ type RefreshTokenStorage interface { DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) } + +type DeviceCodeStorage interface { + CreateDeviceCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) + GetDeviceCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) + DeleteDeviceCodeSession(ctx context.Context, code string) (err error) +} +type UserCodeStorage interface { + CreateUserCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) + GetUserCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) + DeleteUserCodeSession(ctx context.Context, code string) (err error) +} diff --git a/handler/oauth2/strategy.go b/handler/oauth2/strategy.go index 8a469c8c4..26107f731 100644 --- a/handler/oauth2/strategy.go +++ b/handler/oauth2/strategy.go @@ -31,6 +31,8 @@ type CoreStrategy interface { AccessTokenStrategy RefreshTokenStrategy AuthorizeCodeStrategy + DeviceCodeStrategy + UserCodeStrategy } type AccessTokenStrategy interface { @@ -50,3 +52,15 @@ type AuthorizeCodeStrategy interface { GenerateAuthorizeCode(ctx context.Context, requester fosite.Requester) (token string, signature string, err error) ValidateAuthorizeCode(ctx context.Context, requester fosite.Requester, token string) (err error) } + +type DeviceCodeStrategy interface { + DeviceCodeSignature(context context.Context, code string) string + ValidateDeviceCode(context context.Context, r fosite.Requester, code string) (err error) + GenerateDeviceCode() (code string, err error) +} + +type UserCodeStrategy interface { + UserCodeSignature(context context.Context, code string) string + ValidateUserCode(context context.Context, r fosite.Requester, code string) (err error) + GenerateUserCode() (code string, err error) +} diff --git a/handler/oauth2/strategy_hmacsha.go b/handler/oauth2/strategy_hmacsha.go index ad51bf44a..6b722ad6a 100644 --- a/handler/oauth2/strategy_hmacsha.go +++ b/handler/oauth2/strategy_hmacsha.go @@ -23,7 +23,9 @@ package oauth2 import ( "context" + "crypto/rand" "fmt" + "math/big" "strings" "time" @@ -39,6 +41,7 @@ type HMACSHAStrategy struct { fosite.AccessTokenLifespanProvider fosite.RefreshTokenLifespanProvider fosite.AuthorizeCodeLifespanProvider + fosite.DeviceAndUserCodeLifespanProvider } prefix *string } @@ -138,3 +141,56 @@ func (h *HMACSHAStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.Re return h.Enigma.Validate(ctx, h.trimPrefix(token, "ac")) } + +func (h HMACSHAStrategy) generateRandomString(length int) (token string, err error) { + chars := [20]byte{'B', 'C', 'D', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Z'} + chars_length := int64(len(chars)) + + code := make([]byte, length) + for i := 0; i < length; i++ { + num, err := rand.Int(rand.Reader, big.NewInt(chars_length)) + if err != nil { + return "", err + } + code[i] = chars[num.Int64()] + } + return string(code), nil +} + +func (h HMACSHAStrategy) GenerateUserCode() (token string, err error) { + return h.generateRandomString(8) +} + +func (h HMACSHAStrategy) UserCodeSignature(ctx context.Context, token string) string { + return h.Enigma.GenerateHMACForString(token, ctx) +} + +func (h HMACSHAStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) (err error) { + var exp = r.GetSession().GetExpiresAt(fosite.UserCode) + if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) { + return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)))) + } + if !exp.IsZero() && exp.Before(time.Now().UTC()) { + return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", exp)) + } + return nil +} + +func (h HMACSHAStrategy) GenerateDeviceCode() (token string, err error) { + return h.generateRandomString(100) +} + +func (h HMACSHAStrategy) DeviceCodeSignature(ctx context.Context, token string) string { + return h.Enigma.GenerateHMACForString(token, ctx) +} + +func (h HMACSHAStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error) { + var exp = r.GetSession().GetExpiresAt(fosite.UserCode) + if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) { + return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("1 Device code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)))) + } + if !exp.IsZero() && exp.Before(time.Now().UTC()) { + return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("2 Device code expired at '%s'.", exp)) + } + return nil +} diff --git a/handler/oauth2/strategy_hmacsha_test.go b/handler/oauth2/strategy_hmacsha_test.go index ee232e608..3a5c1653d 100644 --- a/handler/oauth2/strategy_hmacsha_test.go +++ b/handler/oauth2/strategy_hmacsha_test.go @@ -36,8 +36,9 @@ import ( var hmacshaStrategy = HMACSHAStrategy{ Enigma: &hmac.HMACStrategy{Config: &fosite.Config{GlobalSecret: []byte("foobarfoobarfoobarfoobarfoobarfoobarfoobarfoobar")}}, Config: &fosite.Config{ - AccessTokenLifespan: time.Hour * 24, - AuthorizeCodeLifespan: time.Hour * 24, + AccessTokenLifespan: time.Hour * 24, + AuthorizeCodeLifespan: time.Hour * 24, + DeviceAndUserCodeLifespan: time.Hour * 24, }, } diff --git a/handler/oauth2/strategy_jwt.go b/handler/oauth2/strategy_jwt.go index 8ce1854c6..85bd20e00 100644 --- a/handler/oauth2/strategy_jwt.go +++ b/handler/oauth2/strategy_jwt.go @@ -149,3 +149,27 @@ func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.Toke return h.Signer.Generate(ctx, claims.ToMapClaims(), jwtSession.GetJWTHeader()) } } + +func (h DefaultJWTStrategy) DeviceCodeSignature(ctx context.Context, token string) string { + return h.HMACSHAStrategy.DeviceCodeSignature(ctx, token) +} + +func (h *DefaultJWTStrategy) GenerateDeviceCode() (token string, err error) { + return h.HMACSHAStrategy.GenerateDeviceCode() +} + +func (h *DefaultJWTStrategy) ValidateDeviceCode(context context.Context, r fosite.Requester, code string) (err error) { + return h.HMACSHAStrategy.ValidateDeviceCode(context, r, code) +} + +func (h DefaultJWTStrategy) UserCodeSignature(ctx context.Context, token string) string { + return h.HMACSHAStrategy.UserCodeSignature(ctx, token) +} + +func (h *DefaultJWTStrategy) GenerateUserCode() (token string, err error) { + return h.HMACSHAStrategy.GenerateUserCode() +} + +func (h *DefaultJWTStrategy) ValidateUserCode(context context.Context, r fosite.Requester, code string) (err error) { + return h.HMACSHAStrategy.ValidateUserCode(context, r, code) +} diff --git a/handler/openid/flow_device_auth.go b/handler/openid/flow_device_auth.go new file mode 100644 index 000000000..a038f698d --- /dev/null +++ b/handler/openid/flow_device_auth.go @@ -0,0 +1,61 @@ +package openid + +import ( + "context" + + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/x/errorsx" + + "github.com/ory/fosite" +) + +type OpenIDConnectDeviceHandler struct { + CoreStorage oauth2.CoreStorage + DeviceCodeStrategy oauth2.DeviceCodeStrategy + UserCodeStrategy oauth2.UserCodeStrategy + + OpenIDConnectRequestStorage OpenIDConnectRequestStorage + OpenIDConnectRequestValidator *OpenIDConnectRequestValidator + + Config fosite.Configurator + + *IDTokenHandleHelper +} + +func (c *OpenIDConnectDeviceHandler) HandleAuthorizeEndpointRequest(ctx context.Context, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) error { + if !(ar.GetGrantedScopes().Has("openid") && ar.GetResponseTypes().ExactOne("device_code")) { + return nil + } + + if !ar.GetClient().GetGrantTypes().Has("urn:ietf:params:oauth:grant-type:device_code") { + return nil + } + + userCode := ar.GetRequestForm().Get("user_code") + userCodeSignature := c.UserCodeStrategy.UserCodeSignature(ctx, userCode) + + userSession, err := c.CoreStorage.GetUserCodeSession(ctx, userCodeSignature, fosite.NewRequest().Session) + if err != nil { + return errorsx.WithStack(fosite.ErrNotFound.WithDebug("User session not found.")) + } + + deviceSession, err := c.CoreStorage.GetDeviceCodeSession(ctx, userSession.GetID(), fosite.NewRequest().Session) + if err != nil { + return errorsx.WithStack(fosite.ErrNotFound.WithDebug("The devicve code has not been issued yet.")) + } + + if len(deviceSession.GetID()) == 0 { + return errorsx.WithStack(fosite.ErrMisconfiguration.WithDebug("The devicve code has not been issued yet, indicating a broken code configuration.")) + } + + if err := c.OpenIDConnectRequestValidator.ValidatePrompt(ctx, ar); err != nil { + return err + } + + // The device code is stored in the ID field of the requester, use this to build the OpenID session as the token endpoint will not know about the user_code + if err := c.OpenIDConnectRequestStorage.CreateOpenIDConnectSession(ctx, userSession.GetID(), ar.Sanitize(oidcParameters)); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + return nil +} diff --git a/handler/openid/flow_device_token.go b/handler/openid/flow_device_token.go new file mode 100644 index 000000000..ecbb7c5d5 --- /dev/null +++ b/handler/openid/flow_device_token.go @@ -0,0 +1,75 @@ +package openid + +import ( + "context" + "fmt" + + "github.com/ory/x/errorsx" + + "github.com/pkg/errors" + + "github.com/ory/fosite" +) + +func (c *OpenIDConnectDeviceHandler) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { + return errorsx.WithStack(fosite.ErrUnknownRequest) +} + +func (c *OpenIDConnectDeviceHandler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + + if !c.CanHandleTokenEndpointRequest(ctx, requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + code := requester.GetRequestForm().Get("device_code") + if code == "" { + return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest.WithHint("device_code missing form body"))) + } + codeSignature := c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) + + authorize, err := c.OpenIDConnectRequestStorage.GetOpenIDConnectSession(ctx, codeSignature, requester) + if errors.Is(err, ErrNoSessionFound) { + return errorsx.WithStack(fosite.ErrUnknownRequest.WithWrap(err).WithDebug(err.Error())) + } else if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + if !authorize.GetGrantedScopes().Has("openid") { + return errorsx.WithStack(fosite.ErrMisconfiguration.WithDebug("An OpenID Connect session was found but the openid scope is missing, probably due to a broken code configuration.")) + } + + if !requester.GetClient().GetGrantTypes().Has("urn:ietf:params:oauth:grant-type:device_code") { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use the authorization grant \"urn:ietf:params:oauth:grant-type:device_code\".")) + } + + sess, ok := requester.GetSession().(Session) + if !ok { + return errorsx.WithStack(fosite.ErrServerError.WithDebug("Failed to generate id token because session must be of type fosite/handler/openid.Session.")) + } + + claims := sess.IDTokenClaims() + if claims.Subject == "" { + return errorsx.WithStack(fosite.ErrServerError.WithDebug("Failed to generate id token because subject is an empty string.")) + } + + claims.AccessTokenHash = c.GetAccessTokenHash(ctx, requester, responder) + + // The response type `id_token` is only required when performing the implicit or hybrid flow, see: + // https://openid.net/specs/openid-connect-registration-1_0.html + // + // if !requester.GetClient().GetResponseTypes().Has("id_token") { + // return errorsx.WithStack(fosite.ErrInvalidGrant.WithDebug("The client is not allowed to use response type id_token")) + // } + + idTokenLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeAuthorizationCode, fosite.IDToken, c.Config.GetIDTokenLifespan(ctx)) + return c.IssueExplicitIDToken(ctx, idTokenLifespan, authorize, responder) +} + +func (c *OpenIDConnectDeviceHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { + return false +} + +func (c *OpenIDConnectDeviceHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { + fmt.Println("CanHandleTokenEndpointRequest OIDC") + return requester.GetGrantTypes().ExactOne("urn:ietf:params:oauth:grant-type:device_code") +} diff --git a/handler/openid/flow_explicit_token.go b/handler/openid/flow_explicit_token.go index c7a16dcc5..2e19e1f54 100644 --- a/handler/openid/flow_explicit_token.go +++ b/handler/openid/flow_explicit_token.go @@ -23,6 +23,7 @@ package openid import ( "context" + "fmt" "github.com/ory/x/errorsx" @@ -83,5 +84,6 @@ func (c *OpenIDConnectExplicitHandler) CanSkipClientAuth(ctx context.Context, re } func (c *OpenIDConnectExplicitHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { + fmt.Println("CanHandleTokenEndpointRequest EXPL TOKEN") return requester.GetGrantTypes().ExactOne("authorization_code") } diff --git a/handler/pkce/handler_device.go b/handler/pkce/handler_device.go new file mode 100644 index 000000000..82c81e083 --- /dev/null +++ b/handler/pkce/handler_device.go @@ -0,0 +1,228 @@ +package pkce + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + + "github.com/ory/x/errorsx" + + "github.com/pkg/errors" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/oauth2" +) + +type HandlerDevice struct { + CoreStorage oauth2.CoreStorage + DeviceCodeStrategy oauth2.DeviceCodeStrategy + UserCodeStrategy oauth2.UserCodeStrategy + AuthorizeCodeStrategy oauth2.AuthorizeCodeStrategy + Storage PKCERequestStorage + Config fosite.Configurator +} + +func (c *HandlerDevice) HandleDeviceAuthorizeEndpointRequest(ctx context.Context, ar fosite.Requester, resp fosite.DeviceAuthorizeResponder) error { + fmt.Println("HandlerDevice :: HandleDeviceAuthorizeEndpointRequest ++") + + if !ar.GetClient().GetGrantTypes().Has("urn:ietf:params:oauth:grant-type:device_code") { + return nil + } + + challenge := ar.GetRequestForm().Get("code_challenge") + method := ar.GetRequestForm().Get("code_challenge_method") + client := ar.GetClient() + userCode := resp.GetUserCode() + + userCodeSignature := c.UserCodeStrategy.UserCodeSignature(ctx, userCode) + + session, err := c.CoreStorage.GetUserCodeSession(ctx, userCodeSignature, fosite.NewRequest().Session) + if err != nil { + return err + } + + if err := c.validate(ctx, challenge, method, client); err != nil { + return err + } + + fmt.Println("PKCE ID : " + session.GetID()) + + if err := c.Storage.CreatePKCERequestSession(ctx, session.GetID(), ar.Sanitize([]string{ + "code_challenge", + "code_challenge_method", + })); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + fmt.Println(resp) + + return nil +} + +func (c *HandlerDevice) validate(ctx context.Context, challenge string, method string, client fosite.Client) error { + if challenge == "" { + // If the server requires Proof Key for Code Exchange (PKCE) by OAuth + // clients and the client does not send the "code_challenge" in + // the request, the authorization endpoint MUST return the authorization + // error response with the "error" value set to "invalid_request". The + // "error_description" or the response of "error_uri" SHOULD explain the + // nature of error, e.g., code challenge required. + + if c.Config.GetEnforcePKCE(ctx) { + return errorsx.WithStack(fosite.ErrInvalidRequest. + WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing."). + WithDebug("The server is configured in a way that enforces PKCE for clients.")) + } + if c.Config.GetEnforcePKCEForPublicClients(ctx) && client.IsPublic() { + return errorsx.WithStack(fosite.ErrInvalidRequest. + WithHint("This client must include a code_challenge when performing the authorize code flow, but it is missing."). + WithDebug("The server is configured in a way that enforces PKCE for this client.")) + } + return nil + } + + // If the server supporting PKCE does not support the requested + // transformation, the authorization endpoint MUST return the + // authorization error response with "error" value set to + // "invalid_request". The "error_description" or the response of + // "error_uri" SHOULD explain the nature of error, e.g., transform + // algorithm not supported. + switch method { + case "S256": + break + case "plain": + fallthrough + case "": + if !c.Config.GetEnablePKCEPlainChallengeMethod(ctx) { + return errorsx.WithStack(fosite.ErrInvalidRequest. + WithHint("Clients must use code_challenge_method=S256, plain is not allowed."). + WithDebug("The server is configured in a way that enforces PKCE S256 as challenge method for clients.")) + } + default: + return errorsx.WithStack(fosite.ErrInvalidRequest. + WithHint("The code_challenge_method is not supported, use S256 instead.")) + } + return nil +} + +func (c *HandlerDevice) HandleTokenEndpointRequest(ctx context.Context, request fosite.AccessRequester) error { + if !c.CanHandleTokenEndpointRequest(ctx, request) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + // code_verifier + // REQUIRED. Code verifier + // + // The "code_challenge_method" is bound to the Authorization Code when + // the Authorization Code is issued. That is the method that the token + // endpoint MUST use to verify the "code_verifier". + verifier := request.GetRequestForm().Get("code_verifier") + + code := request.GetRequestForm().Get("device_code") + if code == "" { + return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest.WithHint("device_code missing form body"))) + } + codeSignature := c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) + + fmt.Println("PKCE ID : " + codeSignature) + + authorizeRequest, err := c.Storage.GetPKCERequestSession(ctx, codeSignature, request.GetSession()) + if errors.Is(err, fosite.ErrNotFound) { + return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request").WithWrap(err).WithDebug(err.Error())) + } else if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + if err := c.Storage.DeletePKCERequestSession(ctx, codeSignature); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + challenge := authorizeRequest.GetRequestForm().Get("code_challenge") + method := authorizeRequest.GetRequestForm().Get("code_challenge_method") + client := authorizeRequest.GetClient() + if err := c.validate(ctx, challenge, method, client); err != nil { + return err + } + + if !c.Config.GetEnforcePKCE(ctx) && challenge == "" && verifier == "" { + return nil + } + + // NOTE: The code verifier SHOULD have enough entropy to make it + // impractical to guess the value. It is RECOMMENDED that the output of + // a suitable random number generator be used to create a 32-octet + // sequence. The octet sequence is then base64url-encoded to produce a + // 43-octet URL safe string to use as the code verifier. + + // Validation + if len(verifier) < 43 { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The PKCE code verifier must be at least 43 characters.")) + } else if len(verifier) > 128 { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The PKCE code verifier can not be longer than 128 characters.")) + } else if verifierWrongFormat.MatchString(verifier) { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The PKCE code verifier must only contain [a-Z], [0-9], '-', '.', '_', '~'.")) + } + + // Upon receipt of the request at the token endpoint, the server + // verifies it by calculating the code challenge from the received + // "code_verifier" and comparing it with the previously associated + // "code_challenge", after first transforming it according to the + // "code_challenge_method" method specified by the client. + // + // If the "code_challenge_method" from Section 4.3 was "S256", the + // received "code_verifier" is hashed by SHA-256, base64url-encoded, and + // then compared to the "code_challenge", i.e.: + // + // BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) == code_challenge + // + // If the "code_challenge_method" from Section 4.3 was "plain", they are + // compared directly, i.e.: + // + // code_verifier == code_challenge. + // + // If the values are equal, the token endpoint MUST continue processing + // as normal (as defined by OAuth 2.0 [RFC6749]). If the values are not + // equal, an error response indicating "invalid_grant" as described in + // Section 5.2 of [RFC6749] MUST be returned. + switch method { + case "S256": + hash := sha256.New() + if _, err := hash.Write([]byte(verifier)); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + if base64.RawURLEncoding.EncodeToString(hash.Sum([]byte{})) != challenge { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The PKCE code challenge did not match the code verifier.")) + } + break + case "plain": + fallthrough + default: + if verifier != challenge { + return errorsx.WithStack(fosite.ErrInvalidGrant. + WithHint("The PKCE code challenge did not match the code verifier.")) + } + } + + return nil +} + +func (c *HandlerDevice) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + return nil +} + +func (c *HandlerDevice) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { + return false +} + +func (c *HandlerDevice) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { + fmt.Println("CanHandleTokenEndpointRequest PKCE") + // grant_type REQUIRED. + // Value MUST be set to "authorization_code" + return requester.GetGrantTypes().ExactOne("urn:ietf:params:oauth:grant-type:device_code") +} diff --git a/internal/access_request.go b/internal/access_request.go index 7f408e080..149b920aa 100644 --- a/internal/access_request.go +++ b/internal/access_request.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_response.go b/internal/access_response.go index 9c383cb79..cdd100769 100644 --- a/internal/access_response.go +++ b/internal/access_response.go @@ -9,7 +9,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_token_storage.go b/internal/access_token_storage.go index a30404098..7d4d80549 100644 --- a/internal/access_token_storage.go +++ b/internal/access_token_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_token_strategy.go b/internal/access_token_strategy.go index a811e51cf..965404f64 100644 --- a/internal/access_token_strategy.go +++ b/internal/access_token_strategy.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_code_storage.go b/internal/authorize_code_storage.go index ef3ab3b3c..75e62d4ac 100644 --- a/internal/authorize_code_storage.go +++ b/internal/authorize_code_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_code_strategy.go b/internal/authorize_code_strategy.go index cb826a23d..8baf3a2c4 100644 --- a/internal/authorize_code_strategy.go +++ b/internal/authorize_code_strategy.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_handler.go b/internal/authorize_handler.go index 964fb7e4a..a660d7ebb 100644 --- a/internal/authorize_handler.go +++ b/internal/authorize_handler.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_request.go b/internal/authorize_request.go index a643e04e8..e3d2f959b 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/client.go b/internal/client.go index 771660924..b2299dc02 100644 --- a/internal/client.go +++ b/internal/client.go @@ -6,10 +6,8 @@ package internal import ( reflect "reflect" - time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -134,20 +132,6 @@ func (mr *MockClientMockRecorder) GetScopes() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockClient)(nil).GetScopes)) } -// GetTokenLifespan mocks base method. -func (m *MockClient) GetTokenLifespan(arg0 fosite.GrantType, arg1 fosite.TokenType, arg2 time.Duration) time.Duration { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTokenLifespan", arg0, arg1, arg2) - ret0, _ := ret[0].(time.Duration) - return ret0 -} - -// GetTokenLifespan indicates an expected call of GetTokenLifespan. -func (mr *MockClientMockRecorder) GetTokenLifespan(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenLifespan", reflect.TypeOf((*MockClient)(nil).GetTokenLifespan), arg0, arg1, arg2) -} - // IsPublic mocks base method. func (m *MockClient) IsPublic() bool { m.ctrl.T.Helper() @@ -161,15 +145,3 @@ func (mr *MockClientMockRecorder) IsPublic() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublic", reflect.TypeOf((*MockClient)(nil).IsPublic)) } - -// SetTokenLifespans mocks base method. -func (m *MockClient) SetTokenLifespans(arg0 map[fosite.TokenType]time.Duration) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTokenLifespans", arg0) -} - -// SetTokenLifespans indicates an expected call of SetTokenLifespans. -func (mr *MockClientMockRecorder) SetTokenLifespans(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTokenLifespans", reflect.TypeOf((*MockClient)(nil).SetTokenLifespans), arg0) -} diff --git a/internal/id_token_strategy.go b/internal/id_token_strategy.go index 2e36d1488..9ec9b8ce6 100644 --- a/internal/id_token_strategy.go +++ b/internal/id_token_strategy.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/introspector.go b/internal/introspector.go index c47cf76e6..6744ec5f7 100644 --- a/internal/introspector.go +++ b/internal/introspector.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_client_storage.go b/internal/oauth2_client_storage.go index 54770c89d..bb15e6c53 100644 --- a/internal/oauth2_client_storage.go +++ b/internal/oauth2_client_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_owner_storage.go b/internal/oauth2_owner_storage.go index ce53de4c8..37ace932f 100644 --- a/internal/oauth2_owner_storage.go +++ b/internal/oauth2_owner_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_revoke_storage.go b/internal/oauth2_revoke_storage.go index 71f8e10a5..c4b28282a 100644 --- a/internal/oauth2_revoke_storage.go +++ b/internal/oauth2_revoke_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_storage.go b/internal/oauth2_storage.go index d29457603..0b9789d53 100644 --- a/internal/oauth2_storage.go +++ b/internal/oauth2_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -64,6 +63,20 @@ func (mr *MockCoreStorageMockRecorder) CreateAuthorizeCodeSession(arg0, arg1, ar return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthorizeCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateAuthorizeCodeSession), arg0, arg1, arg2) } +// CreateDeviceCodeSession mocks base method. +func (m *MockCoreStorage) CreateDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateDeviceCodeSession", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateDeviceCodeSession indicates an expected call of CreateDeviceCodeSession. +func (mr *MockCoreStorageMockRecorder) CreateDeviceCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDeviceCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateDeviceCodeSession), arg0, arg1, arg2) +} + // CreateRefreshTokenSession mocks base method. func (m *MockCoreStorage) CreateRefreshTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { m.ctrl.T.Helper() @@ -78,6 +91,20 @@ func (mr *MockCoreStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2) } +// CreateUserCodeSession mocks base method. +func (m *MockCoreStorage) CreateUserCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUserCodeSession", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateUserCodeSession indicates an expected call of CreateUserCodeSession. +func (mr *MockCoreStorageMockRecorder) CreateUserCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUserCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateUserCodeSession), arg0, arg1, arg2) +} + // DeleteAccessTokenSession mocks base method. func (m *MockCoreStorage) DeleteAccessTokenSession(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -92,6 +119,20 @@ func (mr *MockCoreStorageMockRecorder) DeleteAccessTokenSession(arg0, arg1 inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccessTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).DeleteAccessTokenSession), arg0, arg1) } +// DeleteDeviceCodeSession mocks base method. +func (m *MockCoreStorage) DeleteDeviceCodeSession(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteDeviceCodeSession", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteDeviceCodeSession indicates an expected call of DeleteDeviceCodeSession. +func (mr *MockCoreStorageMockRecorder) DeleteDeviceCodeSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDeviceCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).DeleteDeviceCodeSession), arg0, arg1) +} + // DeleteRefreshTokenSession mocks base method. func (m *MockCoreStorage) DeleteRefreshTokenSession(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -106,6 +147,20 @@ func (mr *MockCoreStorageMockRecorder) DeleteRefreshTokenSession(arg0, arg1 inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRefreshTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).DeleteRefreshTokenSession), arg0, arg1) } +// DeleteUserCodeSession mocks base method. +func (m *MockCoreStorage) DeleteUserCodeSession(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserCodeSession", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserCodeSession indicates an expected call of DeleteUserCodeSession. +func (mr *MockCoreStorageMockRecorder) DeleteUserCodeSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).DeleteUserCodeSession), arg0, arg1) +} + // GetAccessTokenSession mocks base method. func (m *MockCoreStorage) GetAccessTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { m.ctrl.T.Helper() @@ -136,6 +191,21 @@ func (mr *MockCoreStorageMockRecorder) GetAuthorizeCodeSession(arg0, arg1, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizeCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).GetAuthorizeCodeSession), arg0, arg1, arg2) } +// GetDeviceCodeSession mocks base method. +func (m *MockCoreStorage) GetDeviceCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDeviceCodeSession", arg0, arg1, arg2) + ret0, _ := ret[0].(fosite.Requester) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDeviceCodeSession indicates an expected call of GetDeviceCodeSession. +func (mr *MockCoreStorageMockRecorder) GetDeviceCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).GetDeviceCodeSession), arg0, arg1, arg2) +} + // GetRefreshTokenSession mocks base method. func (m *MockCoreStorage) GetRefreshTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { m.ctrl.T.Helper() @@ -151,6 +221,21 @@ func (mr *MockCoreStorageMockRecorder) GetRefreshTokenSession(arg0, arg1, arg2 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).GetRefreshTokenSession), arg0, arg1, arg2) } +// GetUserCodeSession mocks base method. +func (m *MockCoreStorage) GetUserCodeSession(arg0 context.Context, arg1 string, arg2 fosite.Session) (fosite.Requester, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserCodeSession", arg0, arg1, arg2) + ret0, _ := ret[0].(fosite.Requester) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserCodeSession indicates an expected call of GetUserCodeSession. +func (mr *MockCoreStorageMockRecorder) GetUserCodeSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).GetUserCodeSession), arg0, arg1, arg2) +} + // InvalidateAuthorizeCodeSession mocks base method. func (m *MockCoreStorage) InvalidateAuthorizeCodeSession(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() diff --git a/internal/oauth2_strategy.go b/internal/oauth2_strategy.go index 8934ff245..3a6bc859c 100644 --- a/internal/oauth2_strategy.go +++ b/internal/oauth2_strategy.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -64,6 +63,20 @@ func (mr *MockCoreStrategyMockRecorder) AuthorizeCodeSignature(arg0, arg1 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeCodeSignature", reflect.TypeOf((*MockCoreStrategy)(nil).AuthorizeCodeSignature), arg0, arg1) } +// DeviceCodeSignature mocks base method. +func (m *MockCoreStrategy) DeviceCodeSignature(arg0 context.Context, arg1 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeviceCodeSignature", arg0, arg1) + ret0, _ := ret[0].(string) + return ret0 +} + +// DeviceCodeSignature indicates an expected call of DeviceCodeSignature. +func (mr *MockCoreStrategyMockRecorder) DeviceCodeSignature(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceCodeSignature", reflect.TypeOf((*MockCoreStrategy)(nil).DeviceCodeSignature), arg0, arg1) +} + // GenerateAccessToken mocks base method. func (m *MockCoreStrategy) GenerateAccessToken(arg0 context.Context, arg1 fosite.Requester) (string, string, error) { m.ctrl.T.Helper() @@ -96,6 +109,21 @@ func (mr *MockCoreStrategyMockRecorder) GenerateAuthorizeCode(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateAuthorizeCode", reflect.TypeOf((*MockCoreStrategy)(nil).GenerateAuthorizeCode), arg0, arg1) } +// GenerateDeviceCode mocks base method. +func (m *MockCoreStrategy) GenerateDeviceCode() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateDeviceCode") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateDeviceCode indicates an expected call of GenerateDeviceCode. +func (mr *MockCoreStrategyMockRecorder) GenerateDeviceCode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateDeviceCode", reflect.TypeOf((*MockCoreStrategy)(nil).GenerateDeviceCode)) +} + // GenerateRefreshToken mocks base method. func (m *MockCoreStrategy) GenerateRefreshToken(arg0 context.Context, arg1 fosite.Requester) (string, string, error) { m.ctrl.T.Helper() @@ -112,6 +140,21 @@ func (mr *MockCoreStrategyMockRecorder) GenerateRefreshToken(arg0, arg1 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateRefreshToken", reflect.TypeOf((*MockCoreStrategy)(nil).GenerateRefreshToken), arg0, arg1) } +// GenerateUserCode mocks base method. +func (m *MockCoreStrategy) GenerateUserCode() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateUserCode") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateUserCode indicates an expected call of GenerateUserCode. +func (mr *MockCoreStrategyMockRecorder) GenerateUserCode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateUserCode", reflect.TypeOf((*MockCoreStrategy)(nil).GenerateUserCode)) +} + // RefreshTokenSignature mocks base method. func (m *MockCoreStrategy) RefreshTokenSignature(arg0 context.Context, arg1 string) string { m.ctrl.T.Helper() @@ -126,6 +169,20 @@ func (mr *MockCoreStrategyMockRecorder) RefreshTokenSignature(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTokenSignature", reflect.TypeOf((*MockCoreStrategy)(nil).RefreshTokenSignature), arg0, arg1) } +// UserCodeSignature mocks base method. +func (m *MockCoreStrategy) UserCodeSignature(arg0 context.Context, arg1 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UserCodeSignature", arg0, arg1) + ret0, _ := ret[0].(string) + return ret0 +} + +// UserCodeSignature indicates an expected call of UserCodeSignature. +func (mr *MockCoreStrategyMockRecorder) UserCodeSignature(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeSignature", reflect.TypeOf((*MockCoreStrategy)(nil).UserCodeSignature), arg0, arg1) +} + // ValidateAccessToken mocks base method. func (m *MockCoreStrategy) ValidateAccessToken(arg0 context.Context, arg1 fosite.Requester, arg2 string) error { m.ctrl.T.Helper() @@ -154,6 +211,20 @@ func (mr *MockCoreStrategyMockRecorder) ValidateAuthorizeCode(arg0, arg1, arg2 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateAuthorizeCode", reflect.TypeOf((*MockCoreStrategy)(nil).ValidateAuthorizeCode), arg0, arg1, arg2) } +// ValidateDeviceCode mocks base method. +func (m *MockCoreStrategy) ValidateDeviceCode(arg0 context.Context, arg1 fosite.Requester, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateDeviceCode", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateDeviceCode indicates an expected call of ValidateDeviceCode. +func (mr *MockCoreStrategyMockRecorder) ValidateDeviceCode(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateDeviceCode", reflect.TypeOf((*MockCoreStrategy)(nil).ValidateDeviceCode), arg0, arg1, arg2) +} + // ValidateRefreshToken mocks base method. func (m *MockCoreStrategy) ValidateRefreshToken(arg0 context.Context, arg1 fosite.Requester, arg2 string) error { m.ctrl.T.Helper() @@ -167,3 +238,17 @@ func (mr *MockCoreStrategyMockRecorder) ValidateRefreshToken(arg0, arg1, arg2 in mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRefreshToken", reflect.TypeOf((*MockCoreStrategy)(nil).ValidateRefreshToken), arg0, arg1, arg2) } + +// ValidateUserCode mocks base method. +func (m *MockCoreStrategy) ValidateUserCode(arg0 context.Context, arg1 fosite.Requester, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateUserCode", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateUserCode indicates an expected call of ValidateUserCode. +func (mr *MockCoreStrategyMockRecorder) ValidateUserCode(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserCode", reflect.TypeOf((*MockCoreStrategy)(nil).ValidateUserCode), arg0, arg1, arg2) +} diff --git a/internal/openid_id_token_storage.go b/internal/openid_id_token_storage.go index 9d33b855f..11f63c68d 100644 --- a/internal/openid_id_token_storage.go +++ b/internal/openid_id_token_storage.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/pkce_storage_strategy.go b/internal/pkce_storage_strategy.go index 1ab44f9ad..0f25f14f4 100644 --- a/internal/pkce_storage_strategy.go +++ b/internal/pkce_storage_strategy.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/refresh_token_strategy.go b/internal/refresh_token_strategy.go index d0e0422e5..99951c794 100644 --- a/internal/refresh_token_strategy.go +++ b/internal/refresh_token_strategy.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/request.go b/internal/request.go index ce642cb0d..b833be8ba 100644 --- a/internal/request.go +++ b/internal/request.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/revoke_handler.go b/internal/revoke_handler.go index 75418fc70..ed542c112 100644 --- a/internal/revoke_handler.go +++ b/internal/revoke_handler.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/storage.go b/internal/storage.go index a17da5f7e..e7b934cc4 100644 --- a/internal/storage.go +++ b/internal/storage.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/token_handler.go b/internal/token_handler.go index 5e31065d3..5ad46ac60 100644 --- a/internal/token_handler.go +++ b/internal/token_handler.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/oauth2.go b/oauth2.go index 66a539b06..bf03fbe22 100644 --- a/oauth2.go +++ b/oauth2.go @@ -41,6 +41,8 @@ const ( RefreshToken TokenType = "refresh_token" AuthorizeCode TokenType = "authorize_code" IDToken TokenType = "id_token" + DeviceCode TokenType = "device_code" + UserCode TokenType = "user_code" // PushedAuthorizeRequestContext represents the PAR context object PushedAuthorizeRequestContext TokenType = "par_context" @@ -185,6 +187,11 @@ type OAuth2Provider interface { // WritePushedAuthorizeError writes the PAR error WritePushedAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) + + // ToDo add ietf docs + NewDeviceAuthorizeRequest(ctx context.Context, req *http.Request) (Requester, error) + NewDeviceAuthorizeResponse(ctx context.Context, requester Requester) (DeviceAuthorizeResponder, error) + WriteDeviceAuthorizeResponse(rw http.ResponseWriter, requester Requester, responder DeviceAuthorizeResponder) } // IntrospectionResponder is the response object that will be returned when token introspection was successful, @@ -383,3 +390,23 @@ type G11NContext interface { // GetLang returns the current language in the context GetLang() language.Tag } + +type DeviceAuthorizeResponder interface { + GetDeviceCode() string + SetDeviceCode(code string) + + GetUserCode() string + SetUserCode(code string) + + GetVerificationURI() string + SetVerificationURI(uri string) + + GetVerificationURIComplete() string + SetVerificationURIComplete(uri string) + + GetExpiresIn() int64 + SetExpiresIn(seconds int64) + + GetInterval() int + SetInterval(seconds int) +} diff --git a/storage/memory.go b/storage/memory.go index 41b4f86a3..7557984ab 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -59,12 +59,16 @@ type MemoryStore struct { IDSessions map[string]fosite.Requester AccessTokens map[string]fosite.Requester RefreshTokens map[string]StoreRefreshToken + DeviceCodes map[string]fosite.Requester + UserCodes map[string]fosite.Requester PKCES map[string]fosite.Requester Users map[string]MemoryUserRelation BlacklistedJTIs map[string]time.Time // In-memory request ID to token signatures AccessTokenRequestIDs map[string]string RefreshTokenRequestIDs map[string]string + DeviceCodesRequestIDs map[string]string + UserCodesRequestIDs map[string]string // Public keys to check signature in auth grant jwt assertion. IssuerPublicKeys map[string]IssuerPublicKeys PARSessions map[string]fosite.AuthorizeRequester @@ -74,11 +78,15 @@ type MemoryStore struct { idSessionsMutex sync.RWMutex accessTokensMutex sync.RWMutex refreshTokensMutex sync.RWMutex + userCodesMutex sync.RWMutex + deviceCodesMutex sync.RWMutex pkcesMutex sync.RWMutex usersMutex sync.RWMutex blacklistedJTIsMutex sync.RWMutex accessTokenRequestIDsMutex sync.RWMutex refreshTokenRequestIDsMutex sync.RWMutex + userCodesRequestIDsMutex sync.RWMutex + deviceCodesRequestIDsMutex sync.RWMutex issuerPublicKeysMutex sync.RWMutex parSessionsMutex sync.RWMutex } @@ -90,10 +98,14 @@ func NewMemoryStore() *MemoryStore { IDSessions: make(map[string]fosite.Requester), AccessTokens: make(map[string]fosite.Requester), RefreshTokens: make(map[string]StoreRefreshToken), + DeviceCodes: make(map[string]fosite.Requester), + UserCodes: make(map[string]fosite.Requester), PKCES: make(map[string]fosite.Requester), Users: make(map[string]MemoryUserRelation), AccessTokenRequestIDs: make(map[string]string), RefreshTokenRequestIDs: make(map[string]string), + DeviceCodesRequestIDs: make(map[string]string), + UserCodesRequestIDs: make(map[string]string), BlacklistedJTIs: make(map[string]time.Time), IssuerPublicKeys: make(map[string]IssuerPublicKeys), PARSessions: make(map[string]fosite.AuthorizeRequester), @@ -157,8 +169,12 @@ func NewExampleStore() *MemoryStore { AccessTokens: map[string]fosite.Requester{}, RefreshTokens: map[string]StoreRefreshToken{}, PKCES: map[string]fosite.Requester{}, + DeviceCodes: make(map[string]fosite.Requester), + UserCodes: make(map[string]fosite.Requester), AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, + DeviceCodesRequestIDs: make(map[string]string), + UserCodesRequestIDs: make(map[string]string), IssuerPublicKeys: map[string]IssuerPublicKeys{}, } } @@ -514,3 +530,67 @@ func (s *MemoryStore) DeletePARSession(ctx context.Context, requestURI string) ( delete(s.PARSessions, requestURI) return nil } + +func (s *MemoryStore) CreateDeviceCodeSession(_ context.Context, signature string, req fosite.Requester) error { + // We first lock accessTokenRequestIDsMutex and then accessTokensMutex because this is the same order + // locking happens in RevokeAccessToken and using the same order prevents deadlocks. + s.deviceCodesRequestIDsMutex.Lock() + defer s.deviceCodesRequestIDsMutex.Unlock() + s.deviceCodesMutex.Lock() + defer s.deviceCodesMutex.Unlock() + + s.DeviceCodes[signature] = req + s.DeviceCodesRequestIDs[req.GetID()] = signature + return nil +} + +func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { + s.deviceCodesMutex.RLock() + defer s.deviceCodesMutex.RUnlock() + + rel, ok := s.DeviceCodes[signature] + if !ok { + return nil, fosite.ErrNotFound + } + return rel, nil +} + +func (s *MemoryStore) DeleteDeviceCodeSession(_ context.Context, code string) error { + s.deviceCodesMutex.Lock() + defer s.deviceCodesMutex.Unlock() + + delete(s.DeviceCodes, code) + return nil +} + +func (s *MemoryStore) CreateUserCodeSession(_ context.Context, signature string, req fosite.Requester) error { + // We first lock accessTokenRequestIDsMutex and then accessTokensMutex because this is the same order + // locking happens in RevokeAccessToken and using the same order prevents deadlocks. + s.accessTokenRequestIDsMutex.Lock() + defer s.accessTokenRequestIDsMutex.Unlock() + s.accessTokensMutex.Lock() + defer s.accessTokensMutex.Unlock() + + s.AccessTokens[signature] = req + s.AccessTokenRequestIDs[req.GetID()] = signature + return nil +} + +func (s *MemoryStore) GetUserCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { + s.accessTokensMutex.RLock() + defer s.accessTokensMutex.RUnlock() + + rel, ok := s.AccessTokens[signature] + if !ok { + return nil, fosite.ErrNotFound + } + return rel, nil +} + +func (s *MemoryStore) DeleteUserCodeSession(_ context.Context, code string) error { + s.userCodesMutex.Lock() + defer s.userCodesMutex.Unlock() + + delete(s.UserCodes, code) + return nil +} diff --git a/token/hmac/hmacsha.go b/token/hmac/hmacsha.go index 5910a5eb6..b6cd61da7 100644 --- a/token/hmac/hmacsha.go +++ b/token/hmac/hmacsha.go @@ -173,6 +173,17 @@ func (c *HMACStrategy) Signature(token string) string { return split[1] } +func (c *HMACStrategy) GenerateHMACForString(text string, ctx context.Context) string { + var signingKey [32]byte + copy(signingKey[:], c.Config.GetGlobalSecret(ctx)) + + bytes := []byte(text) + hashBytes := c.generateHMAC(ctx, bytes, &signingKey) + + b64 := base64.URLEncoding.EncodeToString(hashBytes) + return b64 +} + func (c *HMACStrategy) generateHMAC(ctx context.Context, data []byte, key *[32]byte) []byte { hasher := c.Config.GetHMACHasher(ctx) if hasher == nil { diff --git a/token/hmac/hmacsha_test.go b/token/hmac/hmacsha_test.go index 23de6764a..2dcbc71b9 100644 --- a/token/hmac/hmacsha_test.go +++ b/token/hmac/hmacsha_test.go @@ -151,3 +151,33 @@ func TestCustomHMAC(t *testing.T) { require.NoError(t, sha512.Validate(context.Background(), token512)) require.EqualError(t, def.Validate(context.Background(), token512), fosite.ErrTokenSignatureMismatch.Error()) } + +func TestGenerateFromString(t *testing.T) { + cg := HMACStrategy{Config: &fosite.Config{ + GlobalSecret: []byte("1234567890123456789012345678901234567890")}, + } + for _, c := range []struct { + text string + hash string + }{ + { + text: "", + hash: "-n7EqD-bXkY3yYMH-ctEAGV8XLkU7Y6Bo6pbyT1agGA=", + }, + { + text: " ", + hash: "zXJvonHTNSOOGj_QKl4RpIX_zXgD2YfXUfwuDKaTTIg=", + }, + { + text: "Test", + hash: "TMeEaHS-cDC2nijiesCNtsOyBqHHtzWqAcWvceQT50g=", + }, + { + text: "AnotherTest1234", + hash: "zHYDOZGjzhVjx5r8RlBhpnJemX5JxEEBUjVT01n3IFM=", + }, + } { + hash := cg.GenerateHMACForString(c.text, context.Background()) + assert.Equal(t, c.hash, hash) + } +}