Skip to content

Commit

Permalink
Refresh tokens in Keycloak (rancher#48102)
Browse files Browse the repository at this point in the history
If the access token has expired we try to fetch a new token with the refresh token. If the refresh token has expired we show a warning message and stop retrying.
  • Loading branch information
raulcabello authored Dec 4, 2024
1 parent efd9ca4 commit da0cf1d
Show file tree
Hide file tree
Showing 6 changed files with 749 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pkg/auth/providerrefresh/refresher.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (r *refresher) refreshAttributes(attribs *v3.UserAttribute) (*v3.UserAttrib
// we no longer want to disable derived tokens, or remove their login tokens for this provider
if err.Error() != "no access" {
errorConfirmingLogins = true
logrus.Errorf("Error refreshing token principals, skipping: %v", err)
logrus.Warnf("Error refreshing token principals, skipping: %v", err)
existingPrincipals := attribs.GroupPrincipals[providerName].Items
if existingPrincipals != nil {
newGroupPrincipals = existingPrincipals
Expand Down
2 changes: 1 addition & 1 deletion pkg/auth/providers/keycloakoidc/keycloak_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (k *keyCloakOIDCProvider) GetPrincipal(principalID string, token v3.Token)
principalType := parts[1]
keyCloakClient, err := k.newClient(config, token)
if err != nil {
logrus.Errorf("[keycloak oidc] GetPrincipal: error creating new http client: %v", err)
logrus.Warnf("[keycloak oidc] GetPrincipal: error creating new http client: %v", err)
return v3.Principal{}, err
}
acct, err := keyCloakClient.getFromKeyCloakByID(externalID, principalType, config)
Expand Down
142 changes: 142 additions & 0 deletions pkg/auth/providers/mocks/tokenmanager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

118 changes: 95 additions & 23 deletions pkg/auth/providers/oidc/oidc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"reflect"
"slices"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -35,14 +36,22 @@ const (
GroupType = "group"
)

type tokenManager interface {
IsMemberOf(token v3.Token, group v3.Principal) bool
UpdateSecret(userID, provider, secret string) error
UserAttributeCreateOrUpdate(userID, provider string, groupPrincipals []v3.Principal, userExtraInfo map[string][]string, loginTime ...time.Time) error
CreateTokenAndSetCookie(userID string, userPrincipal v3.Principal, groupPrincipals []v3.Principal, providerToken string, ttl int, description string, request *types.APIContext) error
GetSecret(userID string, provider string, fallbackTokens []*v3.Token) (string, error)
}

type OpenIDCProvider struct {
Name string
Type string
CTX context.Context
AuthConfigs v3.AuthConfigInterface
Secrets wcorev1.SecretController
UserMGR user.Manager
TokenMGR *tokens.Manager
TokenMGR tokenManager
}

type ClaimInfo struct {
Expand Down Expand Up @@ -108,7 +117,7 @@ func (o *OpenIDCProvider) LoginUser(ctx context.Context, oauthLoginInfo *v32.OID
return userPrincipal, nil, "", userClaimInfo, err
}
}
userInfo, oauth2Token, err := o.getUserInfo(&ctx, config, oauthLoginInfo.Code, &userClaimInfo, "")
userInfo, oauth2Token, err := o.getUserInfoFromAuthCode(&ctx, config, oauthLoginInfo.Code, &userClaimInfo, "")
if err != nil {
return userPrincipal, groupPrincipals, "", userClaimInfo, err
}
Expand Down Expand Up @@ -208,7 +217,6 @@ func (o *OpenIDCProvider) getRedirectURL(config map[string]interface{}) string {

func (o *OpenIDCProvider) RefetchGroupPrincipals(principalID string, secret string) ([]v3.Principal, error) {
var groupPrincipals []v3.Principal
var claimInfo ClaimInfo

config, err := o.GetOIDCConfig()
if err != nil {
Expand All @@ -221,12 +229,16 @@ func (o *OpenIDCProvider) RefetchGroupPrincipals(principalID string, secret stri
logrus.Errorf("[generic oidc] refetchGroupPrincipals: error getting user by principalID: %v", err)
return groupPrincipals, err
}
//do not need userInfo or oauth2Token since we are only processing groups
_, _, err = o.getUserInfo(&o.CTX, config, secret, &claimInfo, user.Name)
var oauthToken oauth2.Token
if err := json.Unmarshal([]byte(secret), &oauthToken); err != nil {
return nil, err
}

claimInfo, err := o.getClaimInfoFromToken(o.CTX, config, &oauthToken, user.Name)
if err != nil {
return groupPrincipals, err
}
return o.getGroupsFromClaimInfo(claimInfo), nil
return o.getGroupsFromClaimInfo(*claimInfo), nil
}

func (o *OpenIDCProvider) CanAccessWithGroupProviders(userPrincipalID string, groupPrincipals []v3.Principal) (bool, error) {
Expand Down Expand Up @@ -375,7 +387,7 @@ func (o *OpenIDCProvider) GetUserExtraAttributes(userPrincipal v3.Principal) map
return extras
}

func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConfig, authCode string, claimInfo *ClaimInfo, userName string) (*oidc.UserInfo, *oauth2.Token, error) {
func (o *OpenIDCProvider) getUserInfoFromAuthCode(ctx *context.Context, config *v32.OIDCConfig, authCode string, claimInfo *ClaimInfo, userName string) (*oidc.UserInfo, *oauth2.Token, error) {
var userInfo *oidc.UserInfo
var oauth2Token *oauth2.Token
var err error
Expand Down Expand Up @@ -414,30 +426,25 @@ func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConf

// Valid will return false if access token is expired
if !oauth2Token.Valid() {
// since token is not valid, the TokenSource func will attempt to refresh the access token
// if the refresh token has not expired
logrus.Debugf("[generic oidc] getUserInfo: attempting to refresh access token")
return userInfo, oauth2Token, fmt.Errorf("not valid token: %w", err)
}
reusedToken, err := oauth2.ReuseTokenSource(oauth2Token, oauthConfig.TokenSource(updatedContext, oauth2Token)).Token()
if err != nil {
return userInfo, oauth2Token, err
}
if !reflect.DeepEqual(oauth2Token, reusedToken) {
o.UpdateToken(reusedToken, userName)

if err := o.UpdateToken(oauth2Token, userName); err != nil {
return nil, nil, err
}

if config.AcrValue != "" {
acrValue, err := parseACRFromAccessToken(oauth2Token.AccessToken)
if err != nil {
return userInfo, oauth2Token, fmt.Errorf("failed to parse ACR from access token: %w", err)
return userInfo, oauth2Token, err
}
if !isValidACR(acrValue, config.AcrValue) {
return userInfo, oauth2Token, errors.New("failed to validate ACR")
}
}

logrus.Debugf("[generic oidc] getUserInfo: getting user info")
userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, reusedToken))
logrus.Debugf("[generic oidc] getUserInfo: getting user info for user %s", userName)
userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, oauth2Token))
if err != nil {
return userInfo, oauth2Token, err
}
Expand All @@ -448,6 +455,71 @@ func (o *OpenIDCProvider) getUserInfo(ctx *context.Context, config *v32.OIDCConf
return userInfo, oauth2Token, nil
}

func (o *OpenIDCProvider) getClaimInfoFromToken(ctx context.Context, config *v32.OIDCConfig, token *oauth2.Token, userName string) (*ClaimInfo, error) {
var userInfo *oidc.UserInfo
var err error
var claimInfo *ClaimInfo

updatedContext, err := AddCertKeyToContext(ctx, config.Certificate, config.PrivateKey)
if err != nil {
return nil, err
}

provider, err := o.getOIDCProvider(updatedContext, config)
if err != nil {
return nil, err
}
oauthConfig := ConfigToOauthConfig(provider.Endpoint(), config)
var verifier = provider.Verifier(&oidc.Config{ClientID: config.ClientID})

// Valid will return false if access token is expired
if !token.Valid() {
// since token is not valid, the TokenSource func will attempt to refresh the access token
// if the refresh token has not expired
logrus.Debugf("[generic oidc] getUserInfo: attempting to refresh access token")
reusedToken, err := oauth2.ReuseTokenSource(token, oauthConfig.TokenSource(updatedContext, token)).Token()
if err != nil {
return nil, err
}
if !reflect.DeepEqual(token, reusedToken) {
err := o.UpdateToken(reusedToken, userName)
if err != nil {
return nil, fmt.Errorf("failed to update token: %w", err)
}
}
token = reusedToken
}

idToken, err := verifier.Verify(updatedContext, token.AccessToken)
if err != nil {
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
if err := idToken.Claims(&claimInfo); err != nil {
return nil, fmt.Errorf("failed to parse claims: %w", err)
}

if config.AcrValue != "" {
acrValue, err := parseACRFromAccessToken(token.AccessToken)
if err != nil {
return nil, err
}
if !isValidACR(acrValue, config.AcrValue) {
return nil, errors.New("failed due to invalid ACR")
}
}

logrus.Debugf("[generic oidc] getUserInfo: getting user info for user %s", userName)
userInfo, err = provider.UserInfo(updatedContext, oauthConfig.TokenSource(updatedContext, token))
if err != nil {
return nil, err
}
if err := userInfo.Claims(&claimInfo); err != nil {
return nil, err
}

return claimInfo, nil
}

func ConfigToOauthConfig(endpoint oauth2.Endpoint, config *v32.OIDCConfig) oauth2.Config {
var finalScopes []string
hasOIDCScope := strings.Contains(config.Scopes, oidc.ScopeOpenID)
Expand Down Expand Up @@ -557,7 +629,7 @@ func isValidACR(claimACR string, configuredACR string) bool {
}

if claimACR != configuredACR {
logrus.Infof("acr value in token does not match configured acr value")
logrus.Infof("[generic oidc] acr value in token does not match configured acr value")
return false
}
return true
Expand All @@ -568,15 +640,15 @@ func parseACRFromAccessToken(accessToken string) (string, error) {
// we already validated the incoming token
token, _, err := parser.ParseUnverified(accessToken, jwt.MapClaims{})
if err != nil {
return "", fmt.Errorf("failed to parse token: %w", err)
return "", fmt.Errorf("failed to parse JWT token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", errors.New("invalid access token jwt.MapClaims format")
return "", errors.New("failed to parse claims in JWT token: invalid jwt.MapClaims format")
}
acrValue, found := claims["acr"].(string)
if !found {
return "", fmt.Errorf("acr claim invalid or not found in token: (acr=%v)", claims["acr"])
return "", fmt.Errorf("ACR claim invalid or not found in token: (acr=%v)", claims["acr"])
}
return acrValue, nil
}
Loading

0 comments on commit da0cf1d

Please sign in to comment.