From 8d6e9522756d33fc3d98dc70a7e8d4d1444ac66f Mon Sep 17 00:00:00 2001 From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> Date: Mon, 17 Jul 2023 14:26:07 +0200 Subject: [PATCH] feat: login and registration code method config --- driver/config/config.go | 49 ++++++++++-- driver/registry_default.go | 34 +++++---- driver/registry_default_registration.go | 3 +- embedx/config.schema.json | 10 +++ selfservice/flow/request.go | 42 +++++------ selfservice/flow/request_test.go | 74 ++++++++++++++++++- selfservice/hook/code_address_verifier.go | 6 +- selfservice/strategy/code/strategy_login.go | 2 +- .../strategy/code/strategy_recovery.go | 4 +- .../strategy/code/strategy_registration.go | 4 +- .../code/strategy_registration_test.go | 3 +- .../strategy/code/strategy_verification.go | 2 +- .../strategy/link/strategy_recovery.go | 4 +- .../strategy/link/strategy_verification.go | 4 +- selfservice/strategy/lookup/login.go | 2 +- selfservice/strategy/lookup/settings.go | 4 +- selfservice/strategy/oidc/strategy_login.go | 4 +- .../strategy/oidc/strategy_registration.go | 4 +- selfservice/strategy/password/login.go | 2 +- selfservice/strategy/password/registration.go | 2 +- selfservice/strategy/password/settings.go | 4 +- selfservice/strategy/profile/strategy.go | 4 +- selfservice/strategy/totp/login.go | 2 +- selfservice/strategy/totp/settings.go | 6 +- selfservice/strategy/webauthn/login.go | 2 +- selfservice/strategy/webauthn/registration.go | 2 +- selfservice/strategy/webauthn/settings.go | 4 +- 27 files changed, 206 insertions(+), 77 deletions(-) diff --git a/driver/config/config.go b/driver/config/config.go index a2928858dcb2..6a085c517873 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -224,9 +224,13 @@ type ( Config json.RawMessage `json:"config"` } SelfServiceStrategy struct { - Enabled bool `json:"enabled"` - Config json.RawMessage `json:"config"` - AllowedFlows []string `json:"allowed_flows"` + Enabled bool `json:"enabled"` + Config json.RawMessage `json:"config"` + } + SelfServiceStrategyCode struct { + RegistrationEnabled bool `json:"registration_enabled"` + LoginEnabled bool `json:"login_enabled"` + *SelfServiceStrategy } Schema struct { ID string `json:"id" koanf:"id"` @@ -732,9 +736,8 @@ func (p *Config) SelfServiceStrategy(ctx context.Context, strategy string) *Self enabledKey := fmt.Sprintf("%s.enabled", basePath) s := &SelfServiceStrategy{ - Enabled: pp.Bool(enabledKey), - Config: json.RawMessage(config), - AllowedFlows: pp.Strings(fmt.Sprintf("%s.allowed_flows", basePath)), + Enabled: pp.Bool(enabledKey), + Config: json.RawMessage(config), } // The default value can easily be overwritten by setting e.g. `{"selfservice": "null"}` which means that @@ -758,6 +761,40 @@ func (p *Config) SelfServiceStrategy(ctx context.Context, strategy string) *Self return s } +func (p *Config) SelfServiceCodeStrategy(ctx context.Context) *SelfServiceStrategyCode { + pp := p.GetProvider(ctx) + + config := "{}" + out, err := pp.Marshal(kjson.Parser()) + if err != nil { + p.l.WithError(err).Warn("Unable to marshal self service strategy configuration.") + } else if c := gjson.GetBytes(out, + fmt.Sprintf("%s.%s.config", ViperKeySelfServiceStrategyConfig, "code")).Raw; len(c) > 0 { + config = c + } + + basePath := fmt.Sprintf("%s.%s", ViperKeySelfServiceStrategyConfig, "code") + enabledKey := fmt.Sprintf("%s.enabled", basePath) + registrationKey := fmt.Sprintf("%s.registration_enabled", basePath) + loginKey := fmt.Sprintf("%s.login_enabled", basePath) + + s := &SelfServiceStrategyCode{ + SelfServiceStrategy: &SelfServiceStrategy{ + Enabled: pp.Bool(enabledKey), + Config: json.RawMessage(config), + }, + RegistrationEnabled: pp.Bool(registrationKey), + LoginEnabled: pp.Bool(loginKey), + } + + if !pp.Exists(enabledKey) { + s.RegistrationEnabled = false + s.LoginEnabled = false + s.Enabled = true + } + return s +} + func (p *Config) SecretsDefault(ctx context.Context) [][]byte { pp := p.GetProvider(ctx) secrets := pp.Strings(ViperKeySecretsDefault) diff --git a/driver/registry_default.go b/driver/registry_default.go index 3b059c7e29fc..bc7a31e76c2b 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -327,16 +327,28 @@ func (m *RegistryDefault) selfServiceStrategies() []interface{} { return m.selfserviceStrategies } +func (m *RegistryDefault) strategyRegistrationEnabled(ctx context.Context, id string) bool { + switch id { + case identity.CredentialsTypeCodeAuth.String(): + return m.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled + default: + return m.Config().SelfServiceStrategy(ctx, id).Enabled + } +} + +func (m *RegistryDefault) strategyLoginEnabled(ctx context.Context, id string) bool { + switch id { + case identity.CredentialsTypeCodeAuth.String(): + return m.Config().SelfServiceCodeStrategy(ctx).LoginEnabled + default: + return m.Config().SelfServiceStrategy(ctx, id).Enabled + } +} + func (m *RegistryDefault) RegistrationStrategies(ctx context.Context) (registrationStrategies registration.Strategies) { for _, strategy := range m.selfServiceStrategies() { if s, ok := strategy.(registration.Strategy); ok { - // the code method needs to be checked explicitly for registration - // TODO: we need to somehow check if the `code` strategy is enabled specifically for registration - // if s.ID() == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceStrategy(ctx, string(s.ID())).RegistrationEnabled { - // registrationStrategies = append(registrationStrategies, s) - // continue - // } - if m.Config().SelfServiceStrategy(ctx, string(s.ID())).Enabled { + if m.strategyRegistrationEnabled(ctx, s.ID().String()) { registrationStrategies = append(registrationStrategies, s) } } @@ -358,13 +370,7 @@ func (m *RegistryDefault) AllRegistrationStrategies() registration.Strategies { func (m *RegistryDefault) LoginStrategies(ctx context.Context) (loginStrategies login.Strategies) { for _, strategy := range m.selfServiceStrategies() { if s, ok := strategy.(login.Strategy); ok { - // the code method needs to be checked explicity for login - // TODO: we need to somwhow check if the `code` strategy is enabled specifically for login - // if s.ID() == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceStrategy(ctx, string(s.ID())).LoginEnabled { - // loginStrategies = append(loginStrategies, s) - // continue - // } - if m.Config().SelfServiceStrategy(ctx, string(s.ID())).Enabled { + if m.strategyLoginEnabled(ctx, s.ID().String()) { loginStrategies = append(loginStrategies, s) } } diff --git a/driver/registry_default_registration.go b/driver/registry_default_registration.go index 306076795824..060afcbdf5c7 100644 --- a/driver/registry_default_registration.go +++ b/driver/registry_default_registration.go @@ -28,8 +28,7 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context, initialHookCount = 1 } - // TODO: this needs to be specific to the flow and not just the `code` general strategy - if m.Config().SelfServiceStrategy(ctx, identity.CredentialsTypeCodeAuth.String()).Enabled { + if credentialsType == identity.CredentialsTypeCodeAuth && m.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled { b = append(b, m.HookCodeAddressVerifier()) initialHookCount += 1 } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index bded3ad9c74d..2296fe657995 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1290,6 +1290,16 @@ "type": "object", "additionalProperties": false, "properties": { + "login_enabled": { + "type": "boolean", + "title": "Enables Login with Code Method", + "default": false + }, + "registration_method": { + "type": "boolean", + "title": "Enables Registration with Code Method", + "default": false + }, "enabled": { "type": "boolean", "title": "Enables Code Method", diff --git a/selfservice/flow/request.go b/selfservice/flow/request.go index beca893e7534..3a68673af580 100644 --- a/selfservice/flow/request.go +++ b/selfservice/flow/request.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/strategy" "github.com/ory/x/decoderx" @@ -77,7 +78,7 @@ func EnsureCSRF(reg interface { var dec = decoderx.NewHTTP() -func MethodEnabledAndAllowedFromRequest(r *http.Request, expected string, d interface { +func MethodEnabledAndAllowedFromRequest(r *http.Request, flow FlowName, expected string, d interface { config.Provider }, ) error { @@ -98,13 +99,13 @@ func MethodEnabledAndAllowedFromRequest(r *http.Request, expected string, d inte return errors.WithStack(err) } - return MethodEnabledAndAllowed(r.Context(), expected, method.Method, d) + return MethodEnabledAndAllowed(r.Context(), flow, expected, method.Method, d) } // TODO: to disable specific flows we need to pass down the flow somehow to this method // we could do this by adding an additional parameter, but not all methods have access to the flow // this adds a lot of refactoring work, so we should think about a better way to do this -func MethodEnabledAndAllowed(ctx context.Context, expected, actual string, d interface { +func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, actual string, d interface { config.Provider }, ) error { @@ -112,27 +113,24 @@ func MethodEnabledAndAllowed(ctx context.Context, expected, actual string, d int return errors.WithStack(ErrStrategyNotResponsible) } - stratConf := d.Config().SelfServiceStrategy(ctx, expected) + var ok bool - err := herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage) + if strings.EqualFold(actual, identity.CredentialsTypeCodeAuth.String()) { + switch flowName { + case RegistrationFlow: + ok = d.Config().SelfServiceCodeStrategy(ctx).RegistrationEnabled + case LoginFlow: + ok = d.Config().SelfServiceCodeStrategy(ctx).LoginEnabled + default: + ok = d.Config().SelfServiceCodeStrategy(ctx).Enabled + } + } else { + ok = d.Config().SelfServiceStrategy(ctx, expected).Enabled + } - if stratConf.Enabled { - return nil + if !ok { + return herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage) } - // TODO: Implement a way to disable specific flows for this strategy - // For example, with Code strategy we might only want to allow login flows or recovery flows. - - // if len(stratConf.AllowedFlows) == 0 { - // return nil - // } - // - // // must match one of the allowed flows since this method is enabled - // for _, s := range stratConf.AllowedFlows { - // if strings.EqualFold(s, string(flow.GetFlowName())) { - // return nil - // } - // } - - return errors.WithStack(err) + return nil } diff --git a/selfservice/flow/request_test.go b/selfservice/flow/request_test.go index d240f9c15daf..4fa39a61bc46 100644 --- a/selfservice/flow/request_test.go +++ b/selfservice/flow/request_test.go @@ -55,7 +55,7 @@ func TestMethodEnabledAndAllowed(t *testing.T) { ctx := context.Background() conf, d := internal.NewFastRegistryWithMocks(t) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := flow.MethodEnabledAndAllowedFromRequest(r, "password", d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, flow.LoginFlow, "password", d); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -91,3 +91,75 @@ func TestMethodEnabledAndAllowed(t *testing.T) { assert.Contains(t, string(body), "The requested resource could not be found") }) } + +func TestMethodCodeEnabledAndAllowed(t *testing.T) { + ctx := context.Background() + conf, d := internal.NewFastRegistryWithMocks(t) + + currentFlow := make(chan flow.FlowName, 1) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f := <-currentFlow + if err := flow.MethodEnabledAndAllowedFromRequest(r, f, "code", d); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + })) + + t.Run("login code allowed", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", true) + currentFlow <- flow.LoginFlow + res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + assert.Equal(t, http.StatusNoContent, res.StatusCode) + }) + + t.Run("login code not allowed", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", false) + currentFlow <- flow.LoginFlow + res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) + require.NoError(t, err) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Contains(t, string(body), "The requested resource could not be found") + }) + + t.Run("registration code allowed", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", true) + currentFlow <- flow.RegistrationFlow + res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + assert.Equal(t, http.StatusNoContent, res.StatusCode) + }) + + t.Run("registration code not allowed", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", false) + currentFlow <- flow.RegistrationFlow + res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) + require.NoError(t, err) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Contains(t, string(body), "The requested resource could not be found") + }) + + t.Run("recovery and verification should still be allowed if registration and login is disabled", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.registration_enabled", false) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.login_enabled", false) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) + + for _, f := range []flow.FlowName{flow.RecoveryFlow, flow.VerificationFlow} { + currentFlow <- f + res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + assert.Equal(t, http.StatusNoContent, res.StatusCode) + } + }) +} diff --git a/selfservice/hook/code_address_verifier.go b/selfservice/hook/code_address_verifier.go index e9f15eaf3d73..d2d50e4be9ff 100644 --- a/selfservice/hook/code_address_verifier.go +++ b/selfservice/hook/code_address_verifier.go @@ -1,8 +1,13 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package hook import ( "net/http" + "github.com/pkg/errors" + "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/registration" @@ -10,7 +15,6 @@ import ( "github.com/ory/kratos/selfservice/strategy/code" "github.com/ory/kratos/session" "github.com/ory/kratos/x" - "github.com/pkg/errors" ) type ( diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index 30df0884dfc9..0e874a6b205f 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -68,7 +68,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.deps); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil { return nil, err } diff --git a/selfservice/strategy/code/strategy_recovery.go b/selfservice/strategy/code/strategy_recovery.go index a7f48b29711c..7fd946c593ae 100644 --- a/selfservice/strategy/code/strategy_recovery.go +++ b/selfservice/strategy/code/strategy_recovery.go @@ -311,7 +311,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F // If the email is present in the submission body, the user needs a new code via resend if f.State != flow.StateChooseMethod && len(body.Email) == 0 { - if err := flow.MethodEnabledAndAllowed(ctx, sID, sID, s.deps); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, flow.RecoveryFlow, sID, sID, s.deps); err != nil { return s.HandleRecoveryError(w, r, nil, body, err) } return s.recoveryUseCode(w, r, body, f) @@ -327,7 +327,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F return errors.WithStack(flow.ErrCompletedByStrategy) } - if err := flow.MethodEnabledAndAllowed(ctx, sID, body.Method, s.deps); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, flow.RecoveryFlow, sID, body.Method, s.deps); err != nil { return s.HandleRecoveryError(w, r, nil, body, err) } diff --git a/selfservice/strategy/code/strategy_registration.go b/selfservice/strategy/code/strategy_registration.go index 46addf9886b6..cc406a521652 100644 --- a/selfservice/strategy/code/strategy_registration.go +++ b/selfservice/strategy/code/strategy_registration.go @@ -118,7 +118,9 @@ func (s *Strategy) handleIdentityTraits(ctx context.Context, f *registration.Flo } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) error { - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.deps); err != nil { + if !s.deps.Config().SelfServiceCodeStrategy(r.Context()).RegistrationEnabled { + } + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil { return err } diff --git a/selfservice/strategy/code/strategy_registration_test.go b/selfservice/strategy/code/strategy_registration_test.go index e942c1840c8f..579c873adb10 100644 --- a/selfservice/strategy/code/strategy_registration_test.go +++ b/selfservice/strategy/code/strategy_registration_test.go @@ -32,7 +32,8 @@ func TestRegistrationCodeStrategy(t *testing.T) { ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/registration.schema.json") - conf.MustSet(ctx, fmt.Sprintf("%s.%s.registration_enabled", config.ViperKeySelfServiceStrategyConfig, string(identity.CredentialsTypeCodeAuth)), false) + conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), false) + conf.MustSet(ctx, fmt.Sprintf("%s.%s.registration_enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth), true) conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh") conf.MustSet(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh"}) conf.MustSet(ctx, config.ViperKeySelfServiceRegistrationAfter+".code.hooks", []map[string]interface{}{ diff --git a/selfservice/strategy/code/strategy_verification.go b/selfservice/strategy/code/strategy_verification.go index fe63a523cda9..aea07b9f6ae9 100644 --- a/selfservice/strategy/code/strategy_verification.go +++ b/selfservice/strategy/code/strategy_verification.go @@ -128,7 +128,7 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio return s.handleVerificationError(w, r, nil, body, err) } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), string(body.getMethod()), s.deps); err != nil { return s.handleVerificationError(w, r, f, body, err) } diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index 81d8dee363fc..40d21eb89764 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -237,7 +237,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F } if len(body.Token) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.RecoveryStrategyID(), s.RecoveryStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.RecoveryStrategyID(), s.RecoveryStrategyID(), s.d); err != nil { return s.HandleRecoveryError(w, r, nil, body, err) } @@ -253,7 +253,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F return errors.WithStack(flow.ErrCompletedByStrategy) } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.RecoveryStrategyID(), body.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.RecoveryStrategyID(), body.Method, s.d); err != nil { return s.HandleRecoveryError(w, r, nil, body, err) } diff --git a/selfservice/strategy/link/strategy_verification.go b/selfservice/strategy/link/strategy_verification.go index 807eb81023a9..804219531d8c 100644 --- a/selfservice/strategy/link/strategy_verification.go +++ b/selfservice/strategy/link/strategy_verification.go @@ -122,14 +122,14 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio } if len(body.Token) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil { return s.handleVerificationError(w, r, nil, body, err) } return s.verificationUseToken(w, r, body, f) } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.VerificationStrategyID(), body.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), body.Method, s.d); err != nil { return s.handleVerificationError(w, r, f, body, err) } diff --git a/selfservice/strategy/lookup/login.go b/selfservice/strategy/lookup/login.go index b5e5fddda7e8..75edc2181059 100644 --- a/selfservice/strategy/lookup/login.go +++ b/selfservice/strategy/lookup/login.go @@ -94,7 +94,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.d); err != nil { return nil, err } diff --git a/selfservice/strategy/lookup/settings.go b/selfservice/strategy/lookup/settings.go index 261336ecdcbc..6f61966e353c 100644 --- a/selfservice/strategy/lookup/settings.go +++ b/selfservice/strategy/lookup/settings.go @@ -108,7 +108,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if p.RegenerateLookup || p.RevealLookup || p.ConfirmLookup || p.DisableLookup { // This method has only two submit buttons p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) } } else { @@ -141,7 +141,7 @@ func (s *Strategy) continueSettingsFlow( ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithLookupMethod, ) error { if p.ConfirmLookup || p.RevealLookup || p.RegenerateLookup || p.DisableLookup { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return err } diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index edf3faeb80ae..da21342033bc 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -159,12 +159,12 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.handleError(w, r, f, "", nil, errors.WithStack(herodot.ErrBadRequest.WithDebug(err.Error()).WithReasonf("Unable to parse HTTP form request: %s", err.Error()))) } - var pid = p.Provider // this can come from both url query and post body + pid := p.Provider // this can come from both url query and post body if pid == "" { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 1e303c50df35..690d4650fa3a 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -127,12 +127,12 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat f.TransientPayload = p.TransientPayload - var pid = p.Provider // this can come from both url query and post body + pid := p.Provider // this can come from both url query and post body if pid == "" { return errors.WithStack(flow.ErrStrategyNotResponsible) } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return s.handleError(w, r, f, pid, nil, err) } diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 30010b80eb54..7a56ebaf45d4 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -51,7 +51,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.d); err != nil { return nil, err } diff --git a/selfservice/strategy/password/registration.go b/selfservice/strategy/password/registration.go index 57ebca990923..ac5c91789745 100644 --- a/selfservice/strategy/password/registration.go +++ b/selfservice/strategy/password/registration.go @@ -78,7 +78,7 @@ func (s *Strategy) decode(p *UpdateRegistrationFlowWithPasswordMethod, r *http.R } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) { - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.d); err != nil { return err } diff --git a/selfservice/strategy/password/settings.go b/selfservice/strategy/password/settings.go index 45e6bc045650..4ba0b115e53a 100644 --- a/selfservice/strategy/password/settings.go +++ b/selfservice/strategy/password/settings.go @@ -75,7 +75,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) } @@ -109,7 +109,7 @@ func (s *Strategy) continueSettingsFlow( w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithPasswordMethod, ) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } diff --git a/selfservice/strategy/profile/strategy.go b/selfservice/strategy/profile/strategy.go index 5c4a68be9a78..e94d779aef61 100644 --- a/selfservice/strategy/profile/strategy.go +++ b/selfservice/strategy/profile/strategy.go @@ -116,7 +116,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, &p, err) } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { return ctxUpdate, err } @@ -144,7 +144,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. } func (s *Strategy) continueFlow(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithProfileMethod) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } diff --git a/selfservice/strategy/totp/login.go b/selfservice/strategy/totp/login.go index e3840eba1b9b..bc9816d265f3 100644 --- a/selfservice/strategy/totp/login.go +++ b/selfservice/strategy/totp/login.go @@ -90,7 +90,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - if err := flow.MethodEnabledAndAllowedFromRequest(r, s.ID().String(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.d); err != nil { return nil, err } diff --git a/selfservice/strategy/totp/settings.go b/selfservice/strategy/totp/settings.go index 928c288be365..0587e87e2d6a 100644 --- a/selfservice/strategy/totp/settings.go +++ b/selfservice/strategy/totp/settings.go @@ -94,10 +94,10 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if p.UnlinkTOTP { // This is a submit so we need to manually set the type to TOTP p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) } - } else if err := flow.MethodEnabledAndAllowedFromRequest(r, s.SettingsStrategyID(), s.d); err != nil { + } else if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, &p, err) } @@ -127,7 +127,7 @@ func (s *Strategy) continueSettingsFlow( w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithTotpMethod, ) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } diff --git a/selfservice/strategy/webauthn/login.go b/selfservice/strategy/webauthn/login.go index 33e2689ea315..857f9c55431d 100644 --- a/selfservice/strategy/webauthn/login.go +++ b/selfservice/strategy/webauthn/login.go @@ -211,7 +211,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, flow.ErrStrategyNotResponsible } - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleLoginError(r, f, err) } diff --git a/selfservice/strategy/webauthn/registration.go b/selfservice/strategy/webauthn/registration.go index 96b67244abaa..2880e3ecf9dc 100644 --- a/selfservice/strategy/webauthn/registration.go +++ b/selfservice/strategy/webauthn/registration.go @@ -113,7 +113,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat } p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return s.handleRegistrationError(w, r, f, &p, err) } diff --git a/selfservice/strategy/webauthn/settings.go b/selfservice/strategy/webauthn/settings.go index a6e513a20c99..7a663f6fa592 100644 --- a/selfservice/strategy/webauthn/settings.go +++ b/selfservice/strategy/webauthn/settings.go @@ -112,7 +112,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if len(p.Register+p.Remove) > 0 { // This method has only two submit buttons p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, &p, err) } } else { @@ -146,7 +146,7 @@ func (s *Strategy) continueSettingsFlow( ctxUpdate *settings.UpdateContext, p *updateSettingsFlowWithWebAuthnMethod, ) error { if len(p.Register+p.Remove) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return err }