Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: jackson provider #4242

Merged
merged 16 commits into from
Dec 19, 2024
5 changes: 5 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ const (
ViperKeyIgnoreNetworkErrors = "selfservice.methods.password.config.ignore_network_errors"
ViperKeyTOTPIssuer = "selfservice.methods.totp.config.issuer"
ViperKeyOIDCBaseRedirectURL = "selfservice.methods.oidc.config.base_redirect_uri"
ViperKeySAMLBaseRedirectURL = "selfservice.methods.saml.config.base_redirect_uri"
ViperKeyWebAuthnRPDisplayName = "selfservice.methods.webauthn.config.rp.display_name"
ViperKeyWebAuthnRPID = "selfservice.methods.webauthn.config.rp.id"
ViperKeyWebAuthnRPOrigin = "selfservice.methods.webauthn.config.rp.origin"
Expand Down Expand Up @@ -616,6 +617,10 @@ func (p *Config) OIDCRedirectURIBase(ctx context.Context) *url.URL {
return p.GetProvider(ctx).URIF(ViperKeyOIDCBaseRedirectURL, p.SelfPublicURL(ctx))
}

func (p *Config) SAMLRedirectURIBase(ctx context.Context) *url.URL {
return p.GetProvider(ctx).URIF(ViperKeySAMLBaseRedirectURL, p.SelfPublicURL(ctx))
}

func (p *Config) IdentityTraitsSchemas(ctx context.Context) (ss Schemas, err error) {
if err = p.GetProvider(ctx).Koanf.Unmarshal(ViperKeyIdentitySchemas, &ss); err != nil {
return ss, nil
Expand Down
6 changes: 2 additions & 4 deletions embedx/embedx.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ package embedx

import (
"bytes"
_ "embed"
"io"

"github.com/pkg/errors"

"github.com/ory/x/otelx"

"github.com/tidwall/gjson"

_ "embed"
"github.com/ory/x/otelx"
)

//go:embed config.schema.json
Expand Down
4 changes: 3 additions & 1 deletion identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const (
CredentialsTypeCodeAuth CredentialsType = "code"
CredentialsTypePasskey CredentialsType = "passkey"
CredentialsTypeProfile CredentialsType = "profile"
CredentialsTypeSAML CredentialsType = "saml"
)

func (c CredentialsType) String() string {
Expand All @@ -99,7 +100,7 @@ func (c CredentialsType) ToUiNodeGroup() node.UiNodeGroup {
switch c {
case CredentialsTypePassword:
return node.PasswordGroup
case CredentialsTypeOIDC:
case CredentialsTypeOIDC, CredentialsTypeSAML:
return node.OpenIDConnectGroup
case CredentialsTypeTOTP:
return node.TOTPGroup
Expand Down Expand Up @@ -138,6 +139,7 @@ func ParseCredentialsType(in string) (CredentialsType, bool) {
for _, t := range []CredentialsType{
CredentialsTypePassword,
CredentialsTypeOIDC,
CredentialsTypeSAML,
CredentialsTypeTOTP,
CredentialsTypeLookup,
CredentialsTypeWebAuthn,
Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ func FindIdentityCredentialsTypeByName(con *pop.Connection, ct identity.Credenti
}

if !found {
return uuid.Nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for nane %s. This is a bug in the code.", ct))
return uuid.Nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for name %q. This is a bug in the code.", ct))
}

return result, nil
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ preLoginHook:
// We only apply the filter on AAL1, because the OIDC strategy can only satsify
// AAL1.
strategyFilters = []StrategyFilter{func(s Strategy) bool {
return s.ID() == identity.CredentialsTypeOIDC
return s.ID() == identity.CredentialsTypeOIDC || s.ID() == identity.CredentialsTypeSAML
}}
}
}
Expand Down
4 changes: 3 additions & 1 deletion selfservice/flow/registration/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft
h.d.Logger().WithError(err).Warnf("ignoring invalid UUID %q in query parameter `organization`", rawOrg)
} else {
f.OrganizationID = uuid.NullUUID{UUID: orgID, Valid: true}
strategyFilters = []StrategyFilter{func(s Strategy) bool { return s.ID() == identity.CredentialsTypeOIDC }}
strategyFilters = []StrategyFilter{func(s Strategy) bool {
return s.ID() == identity.CredentialsTypeOIDC || s.ID() == identity.CredentialsTypeSAML
}}
}
}
for _, s := range h.d.RegistrationStrategies(r.Context(), strategyFilters...) {
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func TestOIDCStrategyOrder(t *testing.T) {

// reorder the strategies
reg.WithSelfserviceStrategies(t, []any{
oidc.NewStrategy(reg),
oidc.NewStrategy(reg, oidc.ForCredentialType(identity.CredentialsTypeOIDC)),
password.NewStrategy(reg),
})

Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/oidc/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ var supportedProviders = map[string]func(config *Configuration, reg Dependencies
"patreon": NewProviderPatreon,
"lark": NewProviderLark,
"x": NewProviderX,
"jackson": NewProviderJackson,
}

func (c ConfigurationCollection) Provider(id string, reg Dependencies) (Provider, error) {
Expand Down
57 changes: 57 additions & 0 deletions selfservice/strategy/oidc/provider_jackson.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oidc

import (
"context"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"

"github.com/ory/x/urlx"
)

type ProviderJackson struct {
*ProviderGenericOIDC
}

func NewProviderJackson(
config *Configuration,
reg Dependencies,
) Provider {
return &ProviderJackson{
ProviderGenericOIDC: &ProviderGenericOIDC{
config: config,
reg: reg,
},
}
}

func (j *ProviderJackson) setProvider(ctx context.Context) {
if j.ProviderGenericOIDC.p == nil {
internalHost := strings.TrimSuffix(j.config.TokenURL, "/api/oauth/token")
config := oidc.ProviderConfig{
IssuerURL: j.config.IssuerURL,
AuthURL: j.config.AuthURL,
TokenURL: j.config.TokenURL,
DeviceAuthURL: "",
UserInfoURL: internalHost + "/api/oauth/userinfo",
JWKSURL: internalHost + "/oauth/jwks",
Algorithms: []string{"RS256"},
}
j.ProviderGenericOIDC.p = config.NewProvider(j.withHTTPClientContext(ctx))
}
}

func (j *ProviderJackson) OAuth2(ctx context.Context) (*oauth2.Config, error) {
j.setProvider(ctx)
endpoint := j.ProviderGenericOIDC.p.Endpoint()
config := j.oauth2ConfigFromEndpoint(ctx, endpoint)
config.RedirectURL = urlx.AppendPaths(
j.reg.Config().SAMLRedirectURIBase(ctx),
"/self-service/methods/saml/callback/"+j.config.ID).String()

return config, nil
}
36 changes: 36 additions & 0 deletions selfservice/strategy/oidc/provider_jackson_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oidc_test

import (
"context"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/strategy/oidc"
)

func TestProviderJackson(t *testing.T) {
_, reg := internal.NewVeryFastRegistryWithoutDB(t)

j := oidc.NewProviderJackson(&oidc.Configuration{
Provider: "jackson",
IssuerURL: "https://www.jackson.com/oauth",
AuthURL: "https://www.jackson.com/oauth/auth",
TokenURL: "https://www.jackson.com/api/oauth/token",
Mapper: "file://./stub/hydra.schema.json",
Scope: []string{"email", "profile"},
ID: "some-id",
}, reg)
assert.NotNil(t, j)

c, err := j.(oidc.OAuth2Provider).OAuth2(context.Background())
require.NoError(t, err)

assert.True(t, strings.HasSuffix(c.RedirectURL, "/self-service/methods/saml/callback/some-id"))
}
53 changes: 41 additions & 12 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
"github.com/ory/kratos/x"
"github.com/ory/x/decoderx"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/jsonx"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"
Expand Down Expand Up @@ -119,9 +118,12 @@ func isForced(req interface{}) bool {
// Strategy implements selfservice.LoginStrategy, selfservice.RegistrationStrategy and selfservice.SettingsStrategy.
// It supports login, registration and settings via OpenID Providers.
type Strategy struct {
d Dependencies
validator *schema.Validator
dec *decoderx.HTTP
d Dependencies
validator *schema.Validator
dec *decoderx.HTTP
credType identity.CredentialsType
handleUnknownProviderError func(err error) error
handleMethodNotAllowedError func(err error) error
}

type AuthCodeContainer struct {
Expand Down Expand Up @@ -203,15 +205,42 @@ func (s *Strategy) redirectToGET(w http.ResponseWriter, r *http.Request, _ httpr
http.Redirect(w, r, dest.String(), http.StatusFound)
}

func NewStrategy(d any) *Strategy {
return &Strategy{
d: d.(Dependencies),
validator: schema.NewValidator(),
type NewStrategyOpt func(s *Strategy)

// ForCredentialType overrides the credentials type for this strategy.
func ForCredentialType(ct identity.CredentialsType) NewStrategyOpt {
return func(s *Strategy) { s.credType = ct }
}

// WithUnknownProviderHandler overrides the error returned when the provider
// cannot be found.
func WithUnknownProviderHandler(handler func(error) error) NewStrategyOpt {
return func(s *Strategy) { s.handleUnknownProviderError = handler }
}

// WithHandleMethodNotAllowedError overrides the error returned when method is
// not allowed.
func WithHandleMethodNotAllowedError(handler func(error) error) NewStrategyOpt {
return func(s *Strategy) { s.handleMethodNotAllowedError = handler }
}

func NewStrategy(d any, opts ...NewStrategyOpt) *Strategy {
s := &Strategy{
d: d.(Dependencies),
validator: schema.NewValidator(),
credType: identity.CredentialsTypeOIDC,
handleUnknownProviderError: func(err error) error { return err },
handleMethodNotAllowedError: func(err error) error { return err },
}
for _, opt := range opts {
opt(s)
}

return s
}

func (s *Strategy) ID() identity.CredentialsType {
return identity.CredentialsTypeOIDC
return s.credType
}

func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.UUID) (flow.Flow, error) {
Expand Down Expand Up @@ -516,8 +545,8 @@ func (s *Strategy) Config(ctx context.Context) (*ConfigurationCollection, error)
var c ConfigurationCollection

conf := s.d.Config().SelfServiceStrategy(ctx, string(s.ID())).Config
if err := jsonx.
NewStrictDecoder(bytes.NewBuffer(conf)).
if err := json.
NewDecoder(bytes.NewBuffer(conf)).
Decode(&c); err != nil {
s.d.Logger().WithError(err).WithField("config", conf)
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to decode OpenID Connect Provider configuration: %s", err))
Expand All @@ -530,7 +559,7 @@ func (s *Strategy) provider(ctx context.Context, id string) (Provider, error) {
if c, err := s.Config(ctx); err != nil {
return nil, err
} else if provider, err := c.Provider(id, s.d); err != nil {
return nil, err
return nil, s.handleUnknownProviderError(err)
} else {
return provider, nil
}
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *h
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.Strategy.processLogin")
defer otelx.End(span, &err)

i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject))
i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identity.OIDCUniqueID(provider.Config().ID, claims.Subject))
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// If no account was found we're "manually" creating a new registration flow and redirecting the browser
Expand Down Expand Up @@ -218,7 +218,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
return nil, s.handleError(ctx, w, r, f, pid, nil, err)
return nil, s.handleError(ctx, w, r, f, pid, nil, s.handleMethodNotAllowedError(err))
}

provider, err := s.provider(ctx, pid)
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
}

if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil {
return s.handleError(ctx, w, r, f, pid, nil, err)
return s.handleError(ctx, w, r, f, pid, nil, s.handleMethodNotAllowedError(err))
}

provider, err := s.provider(ctx, pid)
Expand Down Expand Up @@ -347,7 +347,7 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
}

i.SetCredentials(s.ID(), *creds)
if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, provider.Config().ID, provider.Config().OrganizationID, rf, i); err != nil {
if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, s.ID(), provider.Config().ID, provider.Config().OrganizationID, rf, i); err != nil {
return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err)
}

Expand Down
Loading