Skip to content

Commit

Permalink
TT-13271, fix for token metadata not being cached
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-tyk committed Nov 1, 2024
1 parent 4a14e3a commit 48e627d
Showing 1 changed file with 85 additions and 17 deletions.
102 changes: 85 additions & 17 deletions gateway/mw_oauth2_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -57,14 +58,21 @@ type upstreamOAuthPasswordCache struct {
func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
return decryptedToken, nil
if tokenData != "" {
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}
}

token, err := cache.obtainToken(r.Context(), OAuthSpec)
Expand All @@ -73,10 +81,15 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

Expand Down Expand Up @@ -271,16 +284,26 @@ func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string
return hex.EncodeToString(hash.Sum(nil))
}

type TokenData struct {
Token string `json:"token"`
ExtraMetadata map[string]interface{} `json:"extra_metadata"`
}

func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}

Expand All @@ -290,24 +313,69 @@ func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAut
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

return token.AccessToken, nil
}

func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) {
func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) {
td := TokenData{
Token: encryptedToken,
ExtraMetadata: buildMetadataMap(token, extraMetadataKeys),
}
return json.Marshal(td)
}

func unmarshalTokenData(tokenData string) (TokenData, error) {
var tokenContents TokenData
err := json.Unmarshal([]byte(tokenData), &tokenContents)
if err != nil {
return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err)
}
return tokenContents, nil
}

func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} {
metadataMap := make(map[string]interface{})
for _, key := range extraMetadataKeys {
if val := token.Extra(key); val != "" {
metadataMap[key] = val
}
}
return metadataMap
}

//func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) {
// contextDataObject := ctxGetData(r)
// if contextDataObject == nil {
// contextDataObject = make(map[string]interface{})
// }
// for _, key := range keyList {
// val := token.Extra(key)
// if val != "" {
// contextDataObject[key] = val
// }
// }
// ctxSetData(r, contextDataObject)
//}

func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) {
contextDataObject := ctxGetData(r)
if contextDataObject == nil {
contextDataObject = make(map[string]interface{})
}
for _, key := range keyList {
val := token.Extra(key)
if val != "" {
if val, ok := token[key]; ok && val != "" {
contextDataObject[key] = val
}
}
Expand All @@ -318,13 +386,13 @@ func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, e
const maxRetries = 10
const retryDelay = 100 * time.Millisecond

var token string
var tokenData string
var err error

for i := 0; i < maxRetries; i++ {
token, err = cache.GetKey(cacheKey)
tokenData, err = cache.GetKey(cacheKey)
if err == nil {
return token, nil
return tokenData, nil
}

lockKey := cacheKey + ":lock"
Expand Down

0 comments on commit 48e627d

Please sign in to comment.