diff --git a/pkg/auth/providerrefresh/refresher.go b/pkg/auth/providerrefresh/refresher.go index 088e69200dd..4c5c8c94673 100644 --- a/pkg/auth/providerrefresh/refresher.go +++ b/pkg/auth/providerrefresh/refresher.go @@ -233,7 +233,7 @@ func (r *refresher) refreshAttributes(attribs *v3.UserAttribute) (*v3.UserAttrib // we no longer want to disable derived tokens, or remove their login tokens for this provider if err.Error() != "no access" { errorConfirmingLogins = true - logrus.Errorf("Error refreshing token principals, skipping: %v", err) + logrus.Warnf("Error refreshing token principals, skipping: %v", err) existingPrincipals := attribs.GroupPrincipals[providerName].Items if existingPrincipals != nil { newGroupPrincipals = existingPrincipals diff --git a/pkg/auth/providers/keycloakoidc/keycloak_provider.go b/pkg/auth/providers/keycloakoidc/keycloak_provider.go index cb84730b891..a476a165c8d 100644 --- a/pkg/auth/providers/keycloakoidc/keycloak_provider.go +++ b/pkg/auth/providers/keycloakoidc/keycloak_provider.go @@ -144,7 +144,7 @@ func (k *keyCloakOIDCProvider) GetPrincipal(principalID string, token v3.Token) principalType := parts[1] keyCloakClient, err := k.newClient(config, token) if err != nil { - logrus.Errorf("[keycloak oidc] GetPrincipal: error creating new http client: %v", err) + logrus.Warnf("[keycloak oidc] GetPrincipal: error creating new http client: %v", err) return v3.Principal{}, err } acct, err := keyCloakClient.getFromKeyCloakByID(externalID, principalType, config) diff --git a/pkg/auth/providers/mocks/tokenmanager.go b/pkg/auth/providers/mocks/tokenmanager.go new file mode 100644 index 00000000000..e8fb1c76647 --- /dev/null +++ b/pkg/auth/providers/mocks/tokenmanager.go @@ -0,0 +1,142 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: oidc_provider.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + time "time" + + types "github.com/rancher/norman/types" + v3 "github.com/rancher/rancher/pkg/generated/norman/management.cattle.io/v3" + gomock "go.uber.org/mock/gomock" +) + +// MocktokenManager is a mock of tokenManager interface. +type MocktokenManager struct { + ctrl *gomock.Controller + recorder *MocktokenManagerMockRecorder +} + +// MocktokenManagerMockRecorder is the mock recorder for MocktokenManager. +type MocktokenManagerMockRecorder struct { + mock *MocktokenManager +} + +// NewMocktokenManager creates a new mock instance. +func NewMocktokenManager(ctrl *gomock.Controller) *MocktokenManager { + mock := &MocktokenManager{ctrl: ctrl} + mock.recorder = &MocktokenManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MocktokenManager) EXPECT() *MocktokenManagerMockRecorder { + return m.recorder +} + +// CreateTokenAndSetCookie mocks base method. +func (m *MocktokenManager) CreateTokenAndSetCookie(userID string, userPrincipal v3.Principal, groupPrincipals []v3.Principal, providerToken string, ttl int, description string, request *types.APIContext) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateTokenAndSetCookie", userID, userPrincipal, groupPrincipals, providerToken, ttl, description, request) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateTokenAndSetCookie indicates an expected call of CreateTokenAndSetCookie. +func (mr *MocktokenManagerMockRecorder) CreateTokenAndSetCookie(userID, userPrincipal, groupPrincipals, providerToken, ttl, description, request interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTokenAndSetCookie", reflect.TypeOf((*MocktokenManager)(nil).CreateTokenAndSetCookie), userID, userPrincipal, groupPrincipals, providerToken, ttl, description, request) +} + +// GetGroupsForTokenAuthProvider mocks base method. +func (m *MocktokenManager) GetGroupsForTokenAuthProvider(token *v3.Token) []v3.Principal { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupsForTokenAuthProvider", token) + ret0, _ := ret[0].([]v3.Principal) + return ret0 +} + +// GetGroupsForTokenAuthProvider indicates an expected call of GetGroupsForTokenAuthProvider. +func (mr *MocktokenManagerMockRecorder) GetGroupsForTokenAuthProvider(token interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsForTokenAuthProvider", reflect.TypeOf((*MocktokenManager)(nil).GetGroupsForTokenAuthProvider), token) +} + +// GetSecret mocks base method. +func (m *MocktokenManager) GetSecret(userID, provider string, fallbackTokens []*v3.Token) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSecret", userID, provider, fallbackTokens) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSecret indicates an expected call of GetSecret. +func (mr *MocktokenManagerMockRecorder) GetSecret(userID, provider, fallbackTokens interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSecret", reflect.TypeOf((*MocktokenManager)(nil).GetSecret), userID, provider, fallbackTokens) +} + +// IsMemberOf mocks base method. +func (m *MocktokenManager) IsMemberOf(token v3.Token, group v3.Principal) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsMemberOf", token, group) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsMemberOf indicates an expected call of IsMemberOf. +func (mr *MocktokenManagerMockRecorder) IsMemberOf(token, group interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMemberOf", reflect.TypeOf((*MocktokenManager)(nil).IsMemberOf), token, group) +} + +// UpdateSecret mocks base method. +func (m *MocktokenManager) UpdateSecret(userID, provider, secret string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSecret", userID, provider, secret) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSecret indicates an expected call of UpdateSecret. +func (mr *MocktokenManagerMockRecorder) UpdateSecret(userID, provider, secret interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSecret", reflect.TypeOf((*MocktokenManager)(nil).UpdateSecret), userID, provider, secret) +} + +// UpdateToken mocks base method. +func (m *MocktokenManager) UpdateToken(token *v3.Token) (*v3.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateToken", token) + ret0, _ := ret[0].(*v3.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateToken indicates an expected call of UpdateToken. +func (mr *MocktokenManagerMockRecorder) UpdateToken(token interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MocktokenManager)(nil).UpdateToken), token) +} + +// UserAttributeCreateOrUpdate mocks base method. +func (m *MocktokenManager) UserAttributeCreateOrUpdate(userID, provider string, groupPrincipals []v3.Principal, userExtraInfo map[string][]string, loginTime ...time.Time) error { + m.ctrl.T.Helper() + varargs := []interface{}{userID, provider, groupPrincipals, userExtraInfo} + for _, a := range loginTime { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UserAttributeCreateOrUpdate", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// UserAttributeCreateOrUpdate indicates an expected call of UserAttributeCreateOrUpdate. +func (mr *MocktokenManagerMockRecorder) UserAttributeCreateOrUpdate(userID, provider, groupPrincipals, userExtraInfo interface{}, loginTime ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{userID, provider, groupPrincipals, userExtraInfo}, loginTime...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserAttributeCreateOrUpdate", reflect.TypeOf((*MocktokenManager)(nil).UserAttributeCreateOrUpdate), varargs...) +} diff --git a/pkg/auth/providers/oidc/oidc_provider.go b/pkg/auth/providers/oidc/oidc_provider.go index 0eeb528c757..644539fb0a3 100644 --- a/pkg/auth/providers/oidc/oidc_provider.go +++ b/pkg/auth/providers/oidc/oidc_provider.go @@ -7,6 +7,7 @@ import ( "reflect" "slices" "strings" + "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt" @@ -35,6 +36,14 @@ const ( GroupType = "group" ) +type tokenManager interface { + IsMemberOf(token v3.Token, group v3.Principal) bool + UpdateSecret(userID, provider, secret string) error + UserAttributeCreateOrUpdate(userID, provider string, groupPrincipals []v3.Principal, userExtraInfo map[string][]string, loginTime ...time.Time) error + CreateTokenAndSetCookie(userID string, userPrincipal v3.Principal, groupPrincipals []v3.Principal, providerToken string, ttl int, description string, request *types.APIContext) error + GetSecret(userID string, provider string, fallbackTokens []*v3.Token) (string, error) +} + type OpenIDCProvider struct { Name string Type string @@ -42,7 +51,7 @@ type OpenIDCProvider struct { AuthConfigs v3.AuthConfigInterface Secrets wcorev1.SecretController UserMGR user.Manager - TokenMGR *tokens.Manager + TokenMGR tokenManager } type ClaimInfo struct { @@ -108,7 +117,7 @@ func (o *OpenIDCProvider) LoginUser(ctx context.Context, oauthLoginInfo *v32.OID return userPrincipal, nil, "", userClaimInfo, err } } - userInfo, oauth2Token, err := o.getUserInfo(&ctx, config, oauthLoginInfo.Code, &userClaimInfo, "") + userInfo, oauth2Token, err := o.getUserInfoFromAuthCode(&ctx, config, oauthLoginInfo.Code, &userClaimInfo, "") if err != nil { return userPrincipal, groupPrincipals, "", userClaimInfo, err } @@ -208,7 +217,6 @@ func (o *OpenIDCProvider) getRedirectURL(config map[string]interface{}) string { func (o *OpenIDCProvider) RefetchGroupPrincipals(principalID string, secret string) ([]v3.Principal, error) { var groupPrincipals []v3.Principal - var claimInfo ClaimInfo config, err := o.GetOIDCConfig() if err != nil { @@ -221,12 +229,16 @@ func (o *OpenIDCProvider) RefetchGroupPrincipals(principalID string, secret stri logrus.Errorf("[generic oidc] refetchGroupPrincipals: error getting user by principalID: %v", err) return groupPrincipals, err } - //do not need userInfo or oauth2Token since we are only processing groups - _, _, err = o.getUserInfo(&o.CTX, config, secret, &claimInfo, user.Name) + var oauthToken oauth2.Token + if err := json.Unmarshal([]byte(secret), &oauthToken); err != nil { + return nil, err + } + + claimInfo, err := o.getClaimInfoFromToken(o.CTX, config, &oauthToken, user.Name) if err != nil { return groupPrincipals, err } - return o.getGroupsFromClaimInfo(claimInfo), nil + return o.getGroupsFromClaimInfo(*claimInfo), nil } func (o *OpenIDCProvider) CanAccessWithGroupProviders(userPrincipalID string, groupPrincipals []v3.Principal) (bool, error) { @@ -375,7 +387,7 @@ func (o *OpenIDCProvider) GetUserExtraAttributes(userPrincipal v3.Principal) map return extras } -func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConfig, authCode string, claimInfo *ClaimInfo, userName string) (*oidc.UserInfo, *oauth2.Token, error) { +func (o *OpenIDCProvider) getUserInfoFromAuthCode(ctx *context.Context, config *v32.OIDCConfig, authCode string, claimInfo *ClaimInfo, userName string) (*oidc.UserInfo, *oauth2.Token, error) { var userInfo *oidc.UserInfo var oauth2Token *oauth2.Token var err error @@ -414,30 +426,25 @@ func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConf // Valid will return false if access token is expired if !oauth2Token.Valid() { - // since token is not valid, the TokenSource func will attempt to refresh the access token - // if the refresh token has not expired - logrus.Debugf("[generic oidc] getUserInfo: attempting to refresh access token") + return userInfo, oauth2Token, fmt.Errorf("not valid token: %w", err) } - reusedToken, err := oauth2.ReuseTokenSource(oauth2Token, oauthConfig.TokenSource(updatedContext, oauth2Token)).Token() - if err != nil { - return userInfo, oauth2Token, err - } - if !reflect.DeepEqual(oauth2Token, reusedToken) { - o.UpdateToken(reusedToken, userName) + + if err := o.UpdateToken(oauth2Token, userName); err != nil { + return nil, nil, err } if config.AcrValue != "" { acrValue, err := parseACRFromAccessToken(oauth2Token.AccessToken) if err != nil { - return userInfo, oauth2Token, fmt.Errorf("failed to parse ACR from access token: %w", err) + return userInfo, oauth2Token, err } if !isValidACR(acrValue, config.AcrValue) { return userInfo, oauth2Token, errors.New("failed to validate ACR") } } - logrus.Debugf("[generic oidc] getUserInfo: getting user info") - userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, reusedToken)) + logrus.Debugf("[generic oidc] getUserInfo: getting user info for user %s", userName) + userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, oauth2Token)) if err != nil { return userInfo, oauth2Token, err } @@ -448,6 +455,71 @@ func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConf return userInfo, oauth2Token, nil } +func (o *OpenIDCProvider) getClaimInfoFromToken(ctx context.Context, config *v32.OIDCConfig, token *oauth2.Token, userName string) (*ClaimInfo, error) { + var userInfo *oidc.UserInfo + var err error + var claimInfo *ClaimInfo + + updatedContext, err := AddCertKeyToContext(ctx, config.Certificate, config.PrivateKey) + if err != nil { + return nil, err + } + + provider, err := o.getOIDCProvider(updatedContext, config) + if err != nil { + return nil, err + } + oauthConfig := ConfigToOauthConfig(provider.Endpoint(), config) + var verifier = provider.Verifier(&oidc.Config{ClientID: config.ClientID}) + + // Valid will return false if access token is expired + if !token.Valid() { + // since token is not valid, the TokenSource func will attempt to refresh the access token + // if the refresh token has not expired + logrus.Debugf("[generic oidc] getUserInfo: attempting to refresh access token") + reusedToken, err := oauth2.ReuseTokenSource(token, oauthConfig.TokenSource(updatedContext, token)).Token() + if err != nil { + return nil, err + } + if !reflect.DeepEqual(token, reusedToken) { + err := o.UpdateToken(reusedToken, userName) + if err != nil { + return nil, fmt.Errorf("failed to update token: %w", err) + } + } + token = reusedToken + } + + idToken, err := verifier.Verify(updatedContext, token.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed to verify ID token: %w", err) + } + if err := idToken.Claims(&claimInfo); err != nil { + return nil, fmt.Errorf("failed to parse claims: %w", err) + } + + if config.AcrValue != "" { + acrValue, err := parseACRFromAccessToken(token.AccessToken) + if err != nil { + return nil, err + } + if !isValidACR(acrValue, config.AcrValue) { + return nil, errors.New("failed due to invalid ACR") + } + } + + logrus.Debugf("[generic oidc] getUserInfo: getting user info for user %s", userName) + userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, token)) + if err != nil { + return nil, err + } + if err := userInfo.Claims(&claimInfo); err != nil { + return nil, err + } + + return claimInfo, nil +} + func ConfigToOauthConfig(endpoint oauth2.Endpoint, config *v32.OIDCConfig) oauth2.Config { var finalScopes []string hasOIDCScope := strings.Contains(config.Scopes, oidc.ScopeOpenID) @@ -557,7 +629,7 @@ func isValidACR(claimACR string, configuredACR string) bool { } if claimACR != configuredACR { - logrus.Infof("acr value in token does not match configured acr value") + logrus.Infof("[generic oidc] acr value in token does not match configured acr value") return false } return true @@ -568,15 +640,15 @@ func parseACRFromAccessToken(accessToken string) (string, error) { // we already validated the incoming token token, _, err := parser.ParseUnverified(accessToken, jwt.MapClaims{}) if err != nil { - return "", fmt.Errorf("failed to parse token: %w", err) + return "", fmt.Errorf("failed to parse JWT token: %w", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return "", errors.New("invalid access token jwt.MapClaims format") + return "", errors.New("failed to parse claims in JWT token: invalid jwt.MapClaims format") } acrValue, found := claims["acr"].(string) if !found { - return "", fmt.Errorf("acr claim invalid or not found in token: (acr=%v)", claims["acr"]) + return "", fmt.Errorf("ACR claim invalid or not found in token: (acr=%v)", claims["acr"]) } return acrValue, nil } diff --git a/pkg/auth/providers/oidc/oidc_provider_test.go b/pkg/auth/providers/oidc/oidc_provider_test.go index b1b67a22b7e..5cd628e5c57 100644 --- a/pkg/auth/providers/oidc/oidc_provider_test.go +++ b/pkg/auth/providers/oidc/oidc_provider_test.go @@ -1,12 +1,24 @@ package oidc import ( + "context" + "crypto/rand" + "crypto/rsa" "encoding/base64" + "encoding/json" "fmt" + "net" + "net/http" + "strconv" "testing" + "time" "github.com/golang-jwt/jwt" + v32 "github.com/rancher/rancher/pkg/apis/management.cattle.io/v3" + "github.com/rancher/rancher/pkg/auth/providers/mocks" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" ) func Test_validateACR(t *testing.T) { @@ -91,12 +103,493 @@ func TestParseACRFromAccessToken(t *testing.T) { } } -// generateAccessToken generates an access token with the specified acr. -func generateAccessToken(acr string) string { - claims := jwt.MapClaims{ - "acr": acr, +func TestGetUserInfoFromAuthCode(t *testing.T) { + const ( + providerName = "keycloak" + userId = "user" + ) + ctrl := gomock.NewController(t) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + assert.NoError(t, err) } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, _ := token.SignedString([]byte("test_secret_key")) - return tokenString + tests := map[string]struct { + config func(string) *v32.OIDCConfig + authCode string + claimInfo *ClaimInfo + tokenManagerMock func(token *Token) tokenManager + oidcProviderResponses func(string) oidcResponses + expectedUserInfoSubject string + expectedUserInfoClaimInfo ClaimInfo + expectedErrorMessage string + }{ + "token is updated and userInfo returned": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + tokenManagerMock: func(token *Token) tokenManager { + mock := mocks.NewMocktokenManager(ctrl) + mock.EXPECT().UpdateSecret(userId, providerName, EqToken(token.IDToken)) + + return mock + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + expectedUserInfoSubject: "a8d0d2c4-6543-4546-8f1a-73e1d7dffcbd", + expectedUserInfoClaimInfo: ClaimInfo{ + Subject: "a8d0d2c4-6543-4546-8f1a-73e1d7dffcbd", + PreferredUsername: "admin", + EmailVerified: true, + Groups: []string{"admingroup"}, + FullGroupPath: []string{"/admingroup"}, + }, + }, + "error - invalid certificate": { + config: func(port string) *v32.OIDCConfig { + return &v32.OIDCConfig{ + Issuer: "http://localhost:" + port, + ClientID: "test", + JWKSUrl: "http://localhost:" + port + "/.well-known/jwks.json", + Certificate: "invalid", + PrivateKey: "invalid", + } + }, + tokenManagerMock: func(token *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + expectedErrorMessage: "could not parse cert/key pair: tls: failed to find any PEM data in certificate input", + }, + "error - invalid token from server": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + tokenManagerMock: func(token *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + oidcProviderResponses: func(port string) oidcResponses { + resp := newOIDCResponses(privateKey, port) + resp.token.IDToken = "invalid" + + return resp + }, + expectedErrorMessage: "oidc: malformed jwt", + }, + "error - invalid user response": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + tokenManagerMock: func(token *Token) tokenManager { + mock := mocks.NewMocktokenManager(ctrl) + mock.EXPECT().UpdateSecret(userId, providerName, EqToken(token.IDToken)) + + return mock + }, + oidcProviderResponses: func(port string) oidcResponses { + resp := newOIDCResponses(privateKey, port) + resp.user = "invalid" + + return resp + }, + expectedErrorMessage: "oidc: failed to decode userinfo", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", ":0") // choose any available port + assert.NoError(t, err) + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + oidcResp := test.oidcProviderResponses(port) + server := mockOIDCServer(listener, oidcResp) + defer server.Shutdown(context.TODO()) + o := OpenIDCProvider{ + Name: providerName, + TokenMGR: test.tokenManagerMock(oidcResp.token), + } + ctx := context.TODO() + + userInfo, token, err := o.getUserInfoFromAuthCode(&ctx, test.config(port), test.authCode, test.claimInfo, userId) + + if test.expectedErrorMessage != "" { + assert.ErrorContains(t, err, test.expectedErrorMessage) + } else { + assert.NoError(t, err) + claims := ClaimInfo{} + assert.NoError(t, userInfo.Claims(&claims)) + assert.Equal(t, test.expectedUserInfoSubject, userInfo.Subject) + assert.Equal(t, test.expectedUserInfoClaimInfo, claims) + assert.Equal(t, oidcResp.token.AccessToken, token.AccessToken) //token should be the same as the one returned by the mock oidc server. + } + }) + } +} + +func TestGetClaimInfoFromToken(t *testing.T) { + const ( + providerName = "keycloak" + userId = "user" + ) + + ctrl := gomock.NewController(t) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + assert.NoError(t, err) + } + + tests := map[string]struct { + config func(string) *v32.OIDCConfig + storedToken func(string) *oauth2.Token + tokenManagerMock func(token *Token) tokenManager + oidcProviderResponses func(string) oidcResponses + expectedClaimInfo *ClaimInfo + expectedErrorMessage string + }{ + "get claims with valid token": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + storedToken: func(port string) *oauth2.Token { + token := jwt.New(jwt.SigningMethodRS256) + token.Claims = jwt.StandardClaims{ + Audience: "test", + ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), // expires in the future + Issuer: "http://localhost:" + port, + } + tokenStr, err := token.SignedString(privateKey) + assert.NoError(t, err) + + return &oauth2.Token{ + AccessToken: tokenStr, + Expiry: time.Now().Add(5 * time.Minute), // expires in the future + } + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + tokenManagerMock: func(_ *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + expectedClaimInfo: &ClaimInfo{ + Subject: "a8d0d2c4-6543-4546-8f1a-73e1d7dffcbd", + PreferredUsername: "admin", + EmailVerified: true, + Groups: []string{"admingroup"}, + FullGroupPath: []string{"/admingroup"}, + }, + }, + "token is refreshed and updated when expired": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + storedToken: func(port string) *oauth2.Token { + token := jwt.New(jwt.SigningMethodRS256) + token.Claims = jwt.StandardClaims{ + Audience: "test", + ExpiresAt: time.Unix(0, 0).Unix(), // has expired + Issuer: "http://localhost:" + port, + } + tokenStr, err := token.SignedString(privateKey) + assert.NoError(t, err) + refreshToken := jwt.New(jwt.SigningMethodRS256) + refreshToken.Claims = jwt.StandardClaims{ + Audience: "test", + ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), // expires in the future + Issuer: "http://localhost:" + port, + } + refreshTokenStr, err := refreshToken.SignedString(privateKey) + assert.NoError(t, err) + + return &oauth2.Token{ + AccessToken: tokenStr, + Expiry: time.Unix(0, 0), // has expired + RefreshToken: refreshTokenStr, + } + }, + expectedClaimInfo: &ClaimInfo{ + Subject: "a8d0d2c4-6543-4546-8f1a-73e1d7dffcbd", + PreferredUsername: "admin", + EmailVerified: true, + Groups: []string{"admingroup"}, + FullGroupPath: []string{"/admingroup"}, + }, + tokenManagerMock: func(token *Token) tokenManager { + mock := mocks.NewMocktokenManager(ctrl) + mock.EXPECT().UpdateSecret(userId, providerName, EqToken(token.RefreshToken)) + + return mock + }, + }, + "error - invalid certificate": { + config: func(port string) *v32.OIDCConfig { + return &v32.OIDCConfig{ + Issuer: "http://localhost:" + port, + ClientID: "test", + JWKSUrl: "http://localhost:" + port + "/.well-known/jwks.json", + Certificate: "invalid", + PrivateKey: "invalid", + } + }, + storedToken: func(port string) *oauth2.Token { + return &oauth2.Token{} + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + tokenManagerMock: func(_ *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + expectedClaimInfo: nil, + expectedErrorMessage: "could not parse cert/key pair: tls: failed to find any PEM data in certificate input", + }, + "error - invalid token": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + storedToken: func(port string) *oauth2.Token { + return &oauth2.Token{ + AccessToken: "invalid", + } + }, + oidcProviderResponses: func(port string) oidcResponses { + return newOIDCResponses(privateKey, port) + }, + tokenManagerMock: func(_ *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + expectedClaimInfo: nil, + expectedErrorMessage: "oidc: malformed jwt", + }, + "error - invalid user response": { + config: func(port string) *v32.OIDCConfig { + return newOIDCContext(port) + }, + storedToken: func(port string) *oauth2.Token { + token := jwt.New(jwt.SigningMethodRS256) + token.Claims = jwt.StandardClaims{ + Audience: "test", + ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), // expires in the future + Issuer: "http://localhost:" + port, + } + tokenStr, err := token.SignedString(privateKey) + assert.NoError(t, err) + + return &oauth2.Token{ + AccessToken: tokenStr, + Expiry: time.Now().Add(5 * time.Minute), // expires in the future + } + }, + oidcProviderResponses: func(port string) oidcResponses { + resp := newOIDCResponses(privateKey, port) + resp.user = "invalid" + + return resp + }, + tokenManagerMock: func(_ *Token) tokenManager { + return mocks.NewMocktokenManager(ctrl) + }, + expectedClaimInfo: nil, + expectedErrorMessage: "oidc: failed to decode userinfo", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", ":0") // choose any available port + assert.NoError(t, err) + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + oidcResp := test.oidcProviderResponses(port) + server := mockOIDCServer(listener, oidcResp) + assert.NoError(t, err) + defer server.Shutdown(context.TODO()) + o := OpenIDCProvider{ + Name: providerName, + TokenMGR: test.tokenManagerMock(oidcResp.token), + } + + claimsInfo, err := o.getClaimInfoFromToken(context.TODO(), test.config(port), test.storedToken(port), userId) + + assert.Equal(t, test.expectedClaimInfo, claimsInfo) + if test.expectedErrorMessage == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, test.expectedErrorMessage) + } + }) + } +} + +// mockOIDCServer creates an http server that mocks an OIDC provider. Responses are passed as a parameter. +func mockOIDCServer(listener net.Listener, resp oidcResponses) *http.Server { + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp.config) + }) + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp.jwks) + }) + mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(resp.user)) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp.token) + }) + + server := &http.Server{ + Handler: mux, + } + + go func() { + _ = server.Serve(listener) + }() + + return server +} + +type oidcResponses struct { + user string + config providerJSON + jwks jsonWebKeySet + token *Token +} + +type Token struct { + oauth2.Token + IDToken string `json:"id_token"` +} + +func newOIDCResponses(privateKey *rsa.PrivateKey, port string) oidcResponses { + jwtToken := jwt.New(jwt.SigningMethodRS256) + jwtToken.Claims = jwt.StandardClaims{ + Audience: "test", + ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), // has expired + Issuer: "http://localhost:" + port, + } + jwtSrt, _ := jwtToken.SignedString(privateKey) + // token returned from the /token endpoint + token := &Token{ + Token: oauth2.Token{ + AccessToken: jwtSrt, + Expiry: time.Now().Add(5 * time.Minute), // expires in the future + RefreshToken: jwtSrt, + }, + IDToken: jwtSrt, + } + + return oidcResponses{ + user: `{ + "sub": "a8d0d2c4-6543-4546-8f1a-73e1d7dffcbd", + "email_verified": true, + "groups": [ + "admingroup" + ], + "full_group_path": [ + "/admingroup" + ], + "preferred_username": "admin" + }`, + config: providerJSON{ + Issuer: "http://localhost:" + port, + UserInfoURL: "http://localhost:" + port + "/user", + JWKSURL: "http://localhost:" + port + "/.well-known/jwks.json", + AuthURL: "http://localhost:" + port + "/auth", + TokenURL: "http://localhost:" + port + "/token", + }, + token: token, + jwks: jsonWebKeySet{ + Keys: []jsonWebKey{ + { + Kty: "RSA", + Kid: "example-key-id", + Use: "sig", + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(privateKey.PublicKey.E)), + }, + }, + }, + } +} + +func newOIDCContext(port string) *v32.OIDCConfig { + return &v32.OIDCConfig{ + Issuer: "http://localhost:" + port, + ClientID: "test", + JWKSUrl: "http://localhost:" + port + "/.well-known/jwks.json", + AuthEndpoint: "http://localhost:" + port + "/auth", + TokenEndpoint: "http://localhost:" + port + "/token", + UserInfoEndpoint: "http://localhost:" + port + "/user", + } +} + +type jsonWebKeySet struct { + Keys []jsonWebKey `json:"keys"` +} + +type jsonWebKey struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` +} + +// Helper function to convert a big.Int (exponent) to []byte +func bigIntToBytes(i int) []byte { + var b [4]byte + b[0] = byte(i >> 24) + b[1] = byte(i >> 16) + b[2] = byte(i >> 8) + b[3] = byte(i) + return b[:] +} + +type providerJSON struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + DeviceAuthURL string `json:"device_authorization_endpoint"` + JWKSURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + Algorithms []string `json:"id_token_signing_alg_values_supported"` +} + +// expiryIn is calculated inside the oauth2 library using time.Now, so we just compare the token is equal +type tokenMatcher struct { + accessToken string +} + +func (m tokenMatcher) Matches(i interface{}) bool { + tokenStr, ok := i.(string) + if !ok { + return false + } + token := oauth2.Token{} + err := json.Unmarshal([]byte(tokenStr), &token) + if err != nil { + return false + } + + return token.AccessToken == m.accessToken +} + +func (m tokenMatcher) String() string { + return fmt.Sprintf("is equal to %s", m.accessToken) +} + +func EqToken(accessToken string) gomock.Matcher { + return tokenMatcher{accessToken} } diff --git a/pkg/controllers/management/auth/user_attribute_handler.go b/pkg/controllers/management/auth/user_attribute_handler.go index 0de902b9e6f..bb26971895e 100644 --- a/pkg/controllers/management/auth/user_attribute_handler.go +++ b/pkg/controllers/management/auth/user_attribute_handler.go @@ -4,6 +4,7 @@ package auth import ( "context" + "errors" "fmt" "github.com/rancher/rancher/pkg/auth/providerrefresh" @@ -12,6 +13,7 @@ import ( v3 "github.com/rancher/rancher/pkg/generated/norman/management.cattle.io/v3" "github.com/rancher/rancher/pkg/types/config" "github.com/sirupsen/logrus" + "golang.org/x/oauth2" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -76,6 +78,14 @@ func (c *UserAttributeController) sync(key string, attribs *v3.UserAttribute) (r attribs, err = c.providerRefresh(attribs) if err != nil { + var retrieveErr *oauth2.RetrieveError + // Stop retrying if the token has expired. + if errors.As(err, &retrieveErr) { + if retrieveErr.ErrorCode == "invalid_grant" { + logrus.Warnf("Token has expired. UserAttributes won't be refreshed until the user %s logs in. Error message: %s", name, err) + return nil, nil + } + } return nil, fmt.Errorf("error refreshing user attribute %s: %w", name, err) }