Skip to content

Commit

Permalink
feat: login and registration code method config
Browse files Browse the repository at this point in the history
  • Loading branch information
Benehiko committed Jul 17, 2023
1 parent 929e1e8 commit 8d6e952
Show file tree
Hide file tree
Showing 27 changed files with 206 additions and 77 deletions.
49 changes: 43 additions & 6 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,13 @@ type (
Config json.RawMessage `json:"config"`
}
SelfServiceStrategy struct {
Enabled bool `json:"enabled"`
Config json.RawMessage `json:"config"`
AllowedFlows []string `json:"allowed_flows"`
Enabled bool `json:"enabled"`
Config json.RawMessage `json:"config"`
}
SelfServiceStrategyCode struct {
RegistrationEnabled bool `json:"registration_enabled"`
LoginEnabled bool `json:"login_enabled"`
*SelfServiceStrategy
}
Schema struct {
ID string `json:"id" koanf:"id"`
Expand Down Expand Up @@ -732,9 +736,8 @@ func (p *Config) SelfServiceStrategy(ctx context.Context, strategy string) *Self

enabledKey := fmt.Sprintf("%s.enabled", basePath)
s := &SelfServiceStrategy{
Enabled: pp.Bool(enabledKey),
Config: json.RawMessage(config),
AllowedFlows: pp.Strings(fmt.Sprintf("%s.allowed_flows", basePath)),
Enabled: pp.Bool(enabledKey),
Config: json.RawMessage(config),
}

// The default value can easily be overwritten by setting e.g. `{"selfservice": "null"}` which means that
Expand All @@ -758,6 +761,40 @@ func (p *Config) SelfServiceStrategy(ctx context.Context, strategy string) *Self
return s
}

func (p *Config) SelfServiceCodeStrategy(ctx context.Context) *SelfServiceStrategyCode {
pp := p.GetProvider(ctx)

config := "{}"
out, err := pp.Marshal(kjson.Parser())
if err != nil {
p.l.WithError(err).Warn("Unable to marshal self service strategy configuration.")
} else if c := gjson.GetBytes(out,
fmt.Sprintf("%s.%s.config", ViperKeySelfServiceStrategyConfig, "code")).Raw; len(c) > 0 {
config = c
}

basePath := fmt.Sprintf("%s.%s", ViperKeySelfServiceStrategyConfig, "code")
enabledKey := fmt.Sprintf("%s.enabled", basePath)
registrationKey := fmt.Sprintf("%s.registration_enabled", basePath)
loginKey := fmt.Sprintf("%s.login_enabled", basePath)

s := &SelfServiceStrategyCode{
SelfServiceStrategy: &SelfServiceStrategy{
Enabled: pp.Bool(enabledKey),
Config: json.RawMessage(config),
},
RegistrationEnabled: pp.Bool(registrationKey),
LoginEnabled: pp.Bool(loginKey),
}

if !pp.Exists(enabledKey) {
s.RegistrationEnabled = false
s.LoginEnabled = false
s.Enabled = true
}
return s
}

func (p *Config) SecretsDefault(ctx context.Context) [][]byte {
pp := p.GetProvider(ctx)
secrets := pp.Strings(ViperKeySecretsDefault)
Expand Down
34 changes: 20 additions & 14 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,28 @@ func (m *RegistryDefault) selfServiceStrategies() []interface{} {
return m.selfserviceStrategies
}

func (m *RegistryDefault) strategyRegistrationEnabled(ctx context.Context, id string) bool {
switch id {
case identity.CredentialsTypeCodeAuth.String():
return m.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled
default:
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}
}

func (m *RegistryDefault) strategyLoginEnabled(ctx context.Context, id string) bool {
switch id {
case identity.CredentialsTypeCodeAuth.String():
return m.Config().SelfServiceCodeStrategy(ctx).LoginEnabled
default:
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}
}

func (m *RegistryDefault) RegistrationStrategies(ctx context.Context) (registrationStrategies registration.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(registration.Strategy); ok {
// the code method needs to be checked explicitly for registration
// TODO: we need to somehow check if the `code` strategy is enabled specifically for registration
// if s.ID() == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceStrategy(ctx, string(s.ID())).RegistrationEnabled {
// registrationStrategies = append(registrationStrategies, s)
// continue
// }
if m.Config().SelfServiceStrategy(ctx, string(s.ID())).Enabled {
if m.strategyRegistrationEnabled(ctx, s.ID().String()) {
registrationStrategies = append(registrationStrategies, s)
}
}
Expand All @@ -358,13 +370,7 @@ func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies {
func (m *RegistryDefault) LoginStrategies(ctx context.Context) (loginStrategies login.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(login.Strategy); ok {
// the code method needs to be checked explicity for login
// TODO: we need to somwhow check if the `code` strategy is enabled specifically for login
// if s.ID() == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceStrategy(ctx, string(s.ID())).LoginEnabled {
// loginStrategies = append(loginStrategies, s)
// continue
// }
if m.Config().SelfServiceStrategy(ctx, string(s.ID())).Enabled {
if m.strategyLoginEnabled(ctx, s.ID().String()) {
loginStrategies = append(loginStrategies, s)
}
}
Expand Down
3 changes: 1 addition & 2 deletions driver/registry_default_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context,
initialHookCount = 1
}

// TODO: this needs to be specific to the flow and not just the `code` general strategy
if m.Config().SelfServiceStrategy(ctx, identity.CredentialsTypeCodeAuth.String()).Enabled {
if credentialsType == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled {
b = append(b, m.HookCodeAddressVerifier())
initialHookCount += 1
}
Expand Down
10 changes: 10 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,16 @@
"type": "object",
"additionalProperties": false,
"properties": {
"login_enabled": {
"type": "boolean",
"title": "Enables Login with Code Method",
"default": false
},
"registration_method": {
"type": "boolean",
"title": "Enables Registration with Code Method",
"default": false
},
"enabled": {
"type": "boolean",
"title": "Enables Code Method",
Expand Down
42 changes: 20 additions & 22 deletions selfservice/flow/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/strategy"
"github.com/ory/x/decoderx"

Expand Down Expand Up @@ -77,7 +78,7 @@ func EnsureCSRF(reg interface {

var dec = decoderx.NewHTTP()

func MethodEnabledAndAllowedFromRequest(r *http.Request, expected string, d interface {
func MethodEnabledAndAllowedFromRequest(r *http.Request, flow FlowName, expected string, d interface {
config.Provider
},
) error {
Expand All @@ -98,41 +99,38 @@ func MethodEnabledAndAllowedFromRequest(r *http.Request, expected string, d inte
return errors.WithStack(err)
}

return MethodEnabledAndAllowed(r.Context(), expected, method.Method, d)
return MethodEnabledAndAllowed(r.Context(), flow, expected, method.Method, d)
}

// TODO: to disable specific flows we need to pass down the flow somehow to this method
// we could do this by adding an additional parameter, but not all methods have access to the flow
// this adds a lot of refactoring work, so we should think about a better way to do this
func MethodEnabledAndAllowed(ctx context.Context, expected, actual string, d interface {
func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, actual string, d interface {
config.Provider
},
) error {
if actual != expected {
return errors.WithStack(ErrStrategyNotResponsible)
}

stratConf := d.Config().SelfServiceStrategy(ctx, expected)
var ok bool

err := herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage)
if strings.EqualFold(actual, identity.CredentialsTypeCodeAuth.String()) {
switch flowName {
case RegistrationFlow:
ok = d.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled
case LoginFlow:
ok = d.Config().SelfServiceCodeStrategy(ctx).LoginEnabled
default:
ok = d.Config().SelfServiceCodeStrategy(ctx).Enabled
}
} else {
ok = d.Config().SelfServiceStrategy(ctx, expected).Enabled
}

if stratConf.Enabled {
return nil
if !ok {
return herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage)
}

// TODO: Implement a way to disable specific flows for this strategy
// For example, with Code strategy we might only want to allow login flows or recovery flows.

// if len(stratConf.AllowedFlows) == 0 {
// return nil
// }
//
// // must match one of the allowed flows since this method is enabled
// for _, s := range stratConf.AllowedFlows {
// if strings.EqualFold(s, string(flow.GetFlowName())) {
// return nil
// }
// }

return errors.WithStack(err)
return nil
}
74 changes: 73 additions & 1 deletion selfservice/flow/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestMethodEnabledAndAllowed(t *testing.T) {
ctx := context.Background()
conf, d := internal.NewFastRegistryWithMocks(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := flow.MethodEnabledAndAllowedFromRequest(r, "password", d); err != nil {
if err := flow.MethodEnabledAndAllowedFromRequest(r, flow.LoginFlow, "password", d); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -91,3 +91,75 @@ func TestMethodEnabledAndAllowed(t *testing.T) {
assert.Contains(t, string(body), "The requested resource could not be found")
})
}

func TestMethodCodeEnabledAndAllowed(t *testing.T) {
ctx := context.Background()
conf, d := internal.NewFastRegistryWithMocks(t)

currentFlow := make(chan flow.FlowName, 1)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
f := <-currentFlow
if err := flow.MethodEnabledAndAllowedFromRequest(r, f, "code", d); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}))

t.Run("login code allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", true)
currentFlow <- flow.LoginFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
})

t.Run("login code not allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", false)
currentFlow <- flow.LoginFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Contains(t, string(body), "The requested resource could not be found")
})

t.Run("registration code allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", true)
currentFlow <- flow.RegistrationFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
})

t.Run("registration code not allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", false)
currentFlow <- flow.RegistrationFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Contains(t, string(body), "The requested resource could not be found")
})

t.Run("recovery and verification should still be allowed if registration and login is disabled", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", false)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", false)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true)

for _, f := range []flow.FlowName{flow.RecoveryFlow, flow.VerificationFlow} {
currentFlow <- f
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
}
})
}
6 changes: 5 additions & 1 deletion selfservice/hook/code_address_verifier.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package hook

import (
"net/http"

"github.com/pkg/errors"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/flow/verification"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/pkg/errors"
)

type (
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, err
}

if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.deps); err != nil {
if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/code/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F

// If the email is present in the submission body, the user needs a new code via resend
if f.State != flow.StateChooseMethod && len(body.Email) == 0 {
if err := flow.MethodEnabledAndAllowed(ctx, sID, sID, s.deps); err != nil {
if err := flow.MethodEnabledAndAllowed(ctx, flow.RecoveryFlow, sID, sID, s.deps); err != nil {
return s.HandleRecoveryError(w, r, nil, body, err)
}
return s.recoveryUseCode(w, r, body, f)
Expand All @@ -327,7 +327,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
return errors.WithStack(flow.ErrCompletedByStrategy)
}

if err := flow.MethodEnabledAndAllowed(ctx, sID, body.Method, s.deps); err != nil {
if err := flow.MethodEnabledAndAllowed(ctx, flow.RecoveryFlow, sID, body.Method, s.deps); err != nil {
return s.HandleRecoveryError(w, r, nil, body, err)
}

Expand Down
4 changes: 3 additions & 1 deletion selfservice/strategy/code/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ func (s *Strategy) handleIdentityTraits(ctx context.Context, f *registration.Flo
}

func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) error {
if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.deps); err != nil {
if !s.deps.Config().SelfServiceCodeStrategy(r.Context()).RegistrationEnabled {
}
if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil {
return err
}

Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/code/strategy_registration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func TestRegistrationCodeStrategy(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/registration.schema.json")
conf.MustSet(ctx, fmt.Sprintf("%s.%s.registration_enabled", config.ViperKeySelfServiceStrategyConfig, string(identity.CredentialsTypeCodeAuth)), false)
conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), false)
conf.MustSet(ctx, fmt.Sprintf("%s.%s.registration_enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth), true)
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh")
conf.MustSet(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh"})
conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".code.hooks", []map[string]interface{}{
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/code/strategy_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio
return s.handleVerificationError(w, r, nil, body, err)
}

if err := flow.MethodEnabledAndAllowed(r.Context(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil {
if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil {
return s.handleVerificationError(w, r, f, body, err)
}

Expand Down
Loading

0 comments on commit 8d6e952

Please sign in to comment.