From 42e8aef127fffc3fc253e7a417643e41af40d028 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Wed, 11 Dec 2024 14:24:45 +0100 Subject: [PATCH] code review --- embedx/config.schema.json | 3 ++- selfservice/strategy/oidc/strategy.go | 21 +++++++++++++-------- selfservice/strategy/oidc/strategy_login.go | 8 -------- selfservice/strategy/oidc/strategy_test.go | 13 ------------- 4 files changed, 15 insertions(+), 30 deletions(-) diff --git a/embedx/config.schema.json b/embedx/config.schema.json index bf06568ebe3f..5fbe0e90bdcb 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -678,7 +678,8 @@ "examples": ["be2b30", "my-saml-provider"] }, "label": { - "title": "Optional string which will be used when generating labels for UI buttons.", + "title": "Label", + "description": "Optional string which will be used when generating labels for UI buttons.", "type": "string" }, "mapper_url": { diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index e58924af14c5..f1ed22341626 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -118,10 +118,11 @@ func isForced(req interface{}) bool { // Strategy implements selfservice.LoginStrategy, selfservice.RegistrationStrategy and selfservice.SettingsStrategy. // It supports login, registration and settings via OpenID Providers. type Strategy struct { - d Dependencies - validator *schema.Validator - dec *decoderx.HTTP - credType identity.CredentialsType + d Dependencies + validator *schema.Validator + dec *decoderx.HTTP + credType identity.CredentialsType + handleUnknownProviderError func(err error) error } type AuthCodeContainer struct { @@ -208,12 +209,16 @@ type NewStrategyOpt func(s *Strategy) func ForCredentialType(ct identity.CredentialsType) NewStrategyOpt { return func(s *Strategy) { s.credType = ct } } +func WithUnknownProviderHandler(handler func(error) error) NewStrategyOpt { + return func(s *Strategy) { s.handleUnknownProviderError = handler } +} func NewStrategy(d any, opts ...NewStrategyOpt) *Strategy { s := &Strategy{ - d: d.(Dependencies), - validator: schema.NewValidator(), - credType: identity.CredentialsTypeOIDC, + d: d.(Dependencies), + validator: schema.NewValidator(), + credType: identity.CredentialsTypeOIDC, + handleUnknownProviderError: func(err error) error { return err }, } for _, opt := range opts { opt(s) @@ -542,7 +547,7 @@ func (s *Strategy) provider(ctx context.Context, id string) (Provider, error) { if c, err := s.Config(ctx); err != nil { return nil, err } else if provider, err := c.Provider(id, s.d); err != nil { - return nil, err + return nil, s.handleUnknownProviderError(err) } else { return provider, nil } diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 8bbea3a3b66b..4ec58a25590b 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -206,14 +206,6 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } - // This is a hack for a lack of a `method` field in the form body. - if prefix, _, ok := strings.Cut(pid, ":"); ok { - if prefix != s.ID().String() { - span.SetAttributes(attribute.String("not_responsible_reason", "provider ID prefix does not match strategy")) - return nil, errors.WithStack(flow.ErrStrategyNotResponsible) - } - } - if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" { // the user is sending a method that is not oidc, but the payload includes a provider s.d.Audit(). diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 50797e8b57e1..dde3e1bda607 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -90,7 +90,6 @@ func TestStrategy(t *testing.T) { newOIDCProvider(t, ts, remotePublic, remoteAdmin, "valid"), newOIDCProvider(t, ts, remotePublic, remoteAdmin, "valid2"), newOIDCProvider(t, ts, remotePublic, remoteAdmin, "secondProvider"), - newOIDCProvider(t, ts, remotePublic, remoteAdmin, "saml:provider"), newOIDCProvider(t, ts, remotePublic, remoteAdmin, "claimsViaUserInfo", func(c *oidc.Configuration) { c.ClaimsSource = oidc.ClaimsSourceUserInfo }), @@ -339,18 +338,6 @@ func TestStrategy(t *testing.T) { } }) - t.Run("case=should redirect to UI because strategy not responsible", func(t *testing.T) { - for k, v := range []string{ - loginAction(newBrowserLoginFlow(t, returnTS.URL, time.Minute).ID), - registerAction(newBrowserRegistrationFlow(t, returnTS.URL, time.Minute).ID), - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - res, body := makeRequest(t, "saml:provider", v, url.Values{}) - assert.Contains(t, res.Request.URL.String(), uiTS.URL, "Redirect does not point to UI server. Status: %d, body: %s", res.StatusCode, body) - }) - } - }) - t.Run("case=should fail because flow does not exist", func(t *testing.T) { for k, v := range []string{loginAction(x.NewUUID()), registerAction(x.NewUUID())} { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {