Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Dec 11, 2024
1 parent aec9ff9 commit 42e8aef
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 30 deletions.
3 changes: 2 additions & 1 deletion embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,8 @@
"examples": ["be2b30", "my-saml-provider"]
},
"label": {
"title": "Optional string which will be used when generating labels for UI buttons.",
"title": "Label",
"description": "Optional string which will be used when generating labels for UI buttons.",
"type": "string"
},
"mapper_url": {
Expand Down
21 changes: 13 additions & 8 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ func isForced(req interface{}) bool {
// Strategy implements selfservice.LoginStrategy, selfservice.RegistrationStrategy and selfservice.SettingsStrategy.
// It supports login, registration and settings via OpenID Providers.
type Strategy struct {
d Dependencies
validator *schema.Validator
dec *decoderx.HTTP
credType identity.CredentialsType
d Dependencies
validator *schema.Validator
dec *decoderx.HTTP
credType identity.CredentialsType
handleUnknownProviderError func(err error) error
}

type AuthCodeContainer struct {
Expand Down Expand Up @@ -208,12 +209,16 @@ type NewStrategyOpt func(s *Strategy)
func ForCredentialType(ct identity.CredentialsType) NewStrategyOpt {
return func(s *Strategy) { s.credType = ct }
}
func WithUnknownProviderHandler(handler func(error) error) NewStrategyOpt {
return func(s *Strategy) { s.handleUnknownProviderError = handler }

Check warning on line 213 in selfservice/strategy/oidc/strategy.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/oidc/strategy.go#L212-L213

Added lines #L212 - L213 were not covered by tests
}

func NewStrategy(d any, opts ...NewStrategyOpt) *Strategy {
s := &Strategy{
d: d.(Dependencies),
validator: schema.NewValidator(),
credType: identity.CredentialsTypeOIDC,
d: d.(Dependencies),
validator: schema.NewValidator(),
credType: identity.CredentialsTypeOIDC,
handleUnknownProviderError: func(err error) error { return err },
}
for _, opt := range opts {
opt(s)
Expand Down Expand Up @@ -542,7 +547,7 @@ func (s *Strategy) provider(ctx context.Context, id string) (Provider, error) {
if c, err := s.Config(ctx); err != nil {
return nil, err
} else if provider, err := c.Provider(id, s.d); err != nil {
return nil, err
return nil, s.handleUnknownProviderError(err)
} else {
return provider, nil
}
Expand Down
8 changes: 0 additions & 8 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,6 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}

// This is a hack for a lack of a `method` field in the form body.
if prefix, _, ok := strings.Cut(pid, ":"); ok {
if prefix != s.ID().String() {
span.SetAttributes(attribute.String("not_responsible_reason", "provider ID prefix does not match strategy"))
return nil, errors.WithStack(flow.ErrStrategyNotResponsible)
}
}

if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
// the user is sending a method that is not oidc, but the payload includes a provider
s.d.Audit().
Expand Down
13 changes: 0 additions & 13 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ func TestStrategy(t *testing.T) {
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "valid"),
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "valid2"),
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "secondProvider"),
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "saml:provider"),
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "claimsViaUserInfo", func(c *oidc.Configuration) {
c.ClaimsSource = oidc.ClaimsSourceUserInfo
}),
Expand Down Expand Up @@ -339,18 +338,6 @@ func TestStrategy(t *testing.T) {
}
})

t.Run("case=should redirect to UI because strategy not responsible", func(t *testing.T) {
for k, v := range []string{
loginAction(newBrowserLoginFlow(t, returnTS.URL, time.Minute).ID),
registerAction(newBrowserRegistrationFlow(t, returnTS.URL, time.Minute).ID),
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
res, body := makeRequest(t, "saml:provider", v, url.Values{})
assert.Contains(t, res.Request.URL.String(), uiTS.URL, "Redirect does not point to UI server. Status: %d, body: %s", res.StatusCode, body)
})
}
})

t.Run("case=should fail because flow does not exist", func(t *testing.T) {
for k, v := range []string{loginAction(x.NewUUID()), registerAction(x.NewUUID())} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
Expand Down

0 comments on commit 42e8aef

Please sign in to comment.