Skip to content

Commit

Permalink
ACM-15764: fix access token renewal from the metrics collector (#1796)
Browse files Browse the repository at this point in the history
* add token file renewal

Signed-off-by: Thibault Mange <[email protected]>

* fix lint

Signed-off-by: Thibault Mange <[email protected]>

* clean

Signed-off-by: Thibault Mange <[email protected]>

* simplify renew strategy

Signed-off-by: Thibault Mange <[email protected]>

---------

Signed-off-by: Thibault Mange <[email protected]>
  • Loading branch information
thibaultmg authored Jan 30, 2025
1 parent eb91b06 commit d739ae7
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 14 deletions.
19 changes: 11 additions & 8 deletions collectors/metrics/pkg/forwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,19 @@ func (cfg Config) CreateFromClient(
fromClient.Transport = metricshttp.NewDebugRoundTripper(logger, fromClient.Transport)
}

if len(cfg.FromClientConfig.Token) == 0 && len(cfg.FromClientConfig.TokenFile) > 0 {
data, err := os.ReadFile(cfg.FromClientConfig.TokenFile)
if err != nil {
return nil, fmt.Errorf("unable to read from-token-file: %w", err)
}
cfg.FromClientConfig.Token = strings.TrimSpace(string(data))
if len(cfg.FromClientConfig.Token) > 0 && len(cfg.FromClientConfig.TokenFile) > 0 {
rlogger.Log(logger, rlogger.Info, "msg", "FromClient token is ignored as token file is specified")
}

if len(cfg.FromClientConfig.Token) > 0 {
fromClient.Transport = metricshttp.NewBearerRoundTripper(cfg.FromClientConfig.Token, fromClient.Transport)
if len(cfg.FromClientConfig.TokenFile) > 0 {
tf, err := NewTokenFile(context.Background(), logger, cfg.FromClientConfig.TokenFile, 2*time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to create tokenFile: %w", err)
}
fromClient.Transport = metricshttp.NewBearerRoundTripper(tf.GetToken, fromClient.Transport)
} else if len(cfg.FromClientConfig.Token) > 0 {
getToken := func() string { return cfg.FromClientConfig.Token }
fromClient.Transport = metricshttp.NewBearerRoundTripper(getToken, fromClient.Transport)
}

return metricsclient.New(logger, metrics.clientMetrics, fromClient, cfg.LimitBytes, interval, "federate_from"), nil
Expand Down
162 changes: 162 additions & 0 deletions collectors/metrics/pkg/forwarder/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Red Hat, Inc.
// Copyright Contributors to the Open Cluster Management project
// Licensed under the Apache License 2.0

package forwarder

import (
"context"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"

"github.com/go-kit/log"
"github.com/golang-jwt/jwt/v5"

Check failure on line 17 in collectors/metrics/pkg/forwarder/token.go

View check run for this annotation

Red Hat Konflux / Red Hat Konflux / metrics-collector-acm-213-on-push

collectors/metrics/pkg/forwarder/token.go#L17

github.com/golang-jwt/jwt/[email protected]: Get "https://proxy.golang.org/github.com/golang-jwt/jwt/v5/@v/v5.2.1.zip": dial tcp: lookup proxy.golang.org on 172.30.0.10:53: dial udp 172.30.0.10:53: connect: network is unreachable
rlogger "github.com/stolostron/multicluster-observability-operator/collectors/metrics/pkg/logger"
)

const remainingDurationBeforeBackoff = 10 * time.Minute

var (
ErrEmptyTokenFilePath = errors.New("token file path is empty")
ErrEmptyToken = errors.New("token is empty")
ErrMissingExpirationClaim = errors.New("missing expiration claim")
)

type TokenFile struct {
filePath string
logger log.Logger
readBackoff time.Duration
token string
expiration time.Time
tokenMu sync.RWMutex
}

// NewTokenFile initiates a new TokenFile.
// It reads the token value from the provided filePath and the caller can access this value using the GetToken() method.
// The token value is automatically updated by re-reading the file as the token approaches expiration.
func NewTokenFile(ctx context.Context, logger log.Logger, filePath string, readBackoff time.Duration) (*TokenFile, error) {
if len(filePath) == 0 {
return nil, ErrEmptyTokenFilePath
}

tf := &TokenFile{
filePath: filePath,
logger: logger,
readBackoff: readBackoff,
}

// Initiate token value
if _, err := tf.renewTokenFromFile(); err != nil {
return nil, err
}

go tf.autoRenew(ctx)

return tf, nil
}

func (t *TokenFile) renewTokenFromFile() (bool, error) {
rawToken, err := os.ReadFile(t.filePath)
if err != nil {
return false, fmt.Errorf("failed to read token file: %w", err)
}

token := strings.TrimSpace(string(rawToken))
if len(token) == 0 {
return false, ErrEmptyToken
}

exp, err := parseTokenExpiration(token)
if err != nil {
return false, fmt.Errorf("failed to parse token expiration time: %w", err)
}

t.tokenMu.Lock()
defer t.tokenMu.Unlock()

if t.token == token {
return false, nil
}

t.token = token
t.expiration = exp

return true, nil
}

func (t *TokenFile) GetToken() string {
t.tokenMu.RLock()
defer t.tokenMu.RUnlock()
return t.token
}

// autoRenew automatically re-read the token file to update its value when it approaches the expiration time.
// The objective is to have a simple and robust strategy.
// Most lifetimes are 1y or 1h. Assuming that kubernetes renews the token when it reaches 80% of its lifetime, it is renewed 12 min before exp with 1h lifetime.
// The strategy is to read the token file every backoff duration until success, starting 10 minutes before expiration.
func (t *TokenFile) autoRenew(ctx context.Context) {
for {
t.tokenMu.RLock()
exp := t.expiration
t.tokenMu.RUnlock()

waitTime := computeWaitTime(exp, t.readBackoff, remainingDurationBeforeBackoff)
timer := time.NewTimer(waitTime)
rlogger.Log(t.logger, rlogger.Info, "msg", "Token renewal triggered", "waitTime", waitTime)
select {
case <-ctx.Done():
return
case <-timer.C:
}

ok, err := t.renewTokenFromFile()
if err != nil {
if waitTime <= t.readBackoff {
rlogger.Log(t.logger, rlogger.Error, "msg", "Failed to renew token", "error", err, "expiration", t.expiration, "path", t.filePath)
} else {
rlogger.Log(t.logger, rlogger.Warn, "msg", "Failed to renew token", "error", err, "expiration", t.expiration, "path", t.filePath)
}
}

if !ok && waitTime <= t.readBackoff {
rlogger.Log(t.logger, rlogger.Warn, "msg", "Failed to renew token while approaching expiration, same token read from file", "expiration", t.expiration, "path", t.filePath)
}

if ok {
rlogger.Log(t.logger, rlogger.Info, "msg", "Successful Token renewal from file")
}
}
}

func parseTokenExpiration(token string) (time.Time, error) {
parsedToken, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err)
}

exp, err := parsedToken.Claims.GetExpirationTime()
if err != nil {
return time.Time{}, fmt.Errorf("failed to get expiration time: %w", err)
}

if exp == nil {
return time.Time{}, ErrMissingExpirationClaim
}

return exp.Time, nil
}

func computeWaitTime(exiprationTime time.Time, backoff, remainingDurationBeforeBackoff time.Duration) time.Duration {
timeUntilExp := time.Until(exiprationTime)
timeToWait := timeUntilExp - remainingDurationBeforeBackoff - backoff

if timeToWait < backoff {
timeToWait = backoff
}

return timeToWait
}
133 changes: 133 additions & 0 deletions collectors/metrics/pkg/forwarder/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Red Hat, Inc.
// Copyright Contributors to the Open Cluster Management project
// Licensed under the Apache License 2.0

package forwarder

import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
"os"
"path/filepath"
"testing"
"time"

"github.com/go-kit/log"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)

func TestTokenFile_Renewal(t *testing.T) {
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
assert.NoError(t, err)

// Create test token with close expiration time, and save it in a file
expiresAt := time.Now().Add(3 * time.Second)
claims := jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
tokenStr, err := token.SignedString(privateKey)
assert.NoError(t, err)
tmpFile := filepath.Join(t.TempDir(), "token")
err = os.WriteFile(tmpFile, []byte(tokenStr), 0644)
assert.NoError(t, err)

// Create token file with short backoff and wait to trigger failing and finally succesful reads
backoff := 1 * time.Second
tf, err := NewTokenFile(context.Background(), log.NewLogfmtLogger(os.Stderr), tmpFile, backoff)
assert.NoError(t, err)
assert.Equal(t, tokenStr, tf.GetToken())
time.Sleep(2 * backoff)

// Update token file
expiresAt = time.Now().Add(1 * time.Hour)
claims = jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
}
newToken := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
newTokenStr, err := newToken.SignedString(privateKey)
assert.NoError(t, err)
assert.NotEqual(t, tokenStr, newTokenStr)
err = os.WriteFile(tmpFile, []byte(newTokenStr), 0644)
assert.NoError(t, err)

// Check that the token has been updated
time.Sleep(2 * backoff)
assert.Equal(t, newTokenStr, tf.GetToken())
}

func TestTokenFile_ComputeWaitTime(t *testing.T) {
testCases := map[string]struct {
backoff time.Duration
expiration time.Time
minDuration time.Duration
expects time.Duration
}{
"no backoff": {
expiration: time.Now().Add(30 * time.Minute),
backoff: 2 * time.Minute,
minDuration: 10 * time.Minute,
expects: 20 * time.Minute,
},
"approaching remaining duration before backoff": {
expiration: time.Now().Add(12 * time.Minute),
backoff: 2 * time.Minute,
minDuration: 10 * time.Minute,
expects: 2 * time.Minute,
},
"below remaining duration before backoff": {
expiration: time.Now().Add(5 * time.Minute),
backoff: 2 * time.Minute,
minDuration: 10 * time.Minute,
expects: 2 * time.Minute,
},
"expired": {
expiration: time.Now().Add(-10 * time.Minute),
backoff: 2 * time.Minute,
minDuration: 10 * time.Minute,
expects: 2 * time.Minute,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
res := computeWaitTime(tc.expiration, tc.backoff, tc.minDuration)
assert.InEpsilon(t, tc.expects.Seconds(), res.Seconds(), 1, fmt.Sprintf("expected %.1f seconds, got %.1f seconds", tc.expects.Seconds(), res.Seconds()))
})
}
}

func TestTokenFile_ParseExpiration(t *testing.T) {
// Invalid token
_, err := parseTokenExpiration("aaa.bbb.ccc")
assert.Error(t, err)

// No expiration
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
assert.NoError(t, err)
claims := jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
tokenStr, err := token.SignedString(privateKey)
assert.NoError(t, err)
assert.NotEmpty(t, tokenStr)
_, err = parseTokenExpiration(tokenStr)
assert.Error(t, err)

// Valid expiration
expiresAt := time.Unix(1737557854, 0)
claims = jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
}
token = jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
tokenStr, err = token.SignedString(privateKey)
assert.NoError(t, err)
assert.NotEmpty(t, tokenStr)
expiration, err := parseTokenExpiration(tokenStr)
assert.NoError(t, err)
assert.Equal(t, expiresAt.Unix(), expiration.Unix())
}
12 changes: 7 additions & 5 deletions collectors/metrics/pkg/http/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ import (
"github.com/stolostron/multicluster-observability-operator/collectors/metrics/pkg/logger"
)

type tokenGetter func() string

type bearerRoundTripper struct {
token string
wrapper http.RoundTripper
getToken tokenGetter
wrapper http.RoundTripper
}

func NewBearerRoundTripper(token string, rt http.RoundTripper) http.RoundTripper {
return &bearerRoundTripper{token: token, wrapper: rt}
func NewBearerRoundTripper(token tokenGetter, rt http.RoundTripper) http.RoundTripper {
return &bearerRoundTripper{getToken: token, wrapper: rt}
}

func (rt *bearerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", rt.token))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", rt.getToken()))
return rt.wrapper.RoundTrip(req)
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/go-kit/log v0.2.1
github.com/go-logr/logr v1.4.2
github.com/gogo/protobuf v1.3.2
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang/protobuf v1.5.4
github.com/golang/snappy v0.0.4
github.com/google/go-cmp v0.6.0
Expand Down Expand Up @@ -115,7 +116,6 @@ require (
github.com/go-openapi/validate v0.23.0 // indirect
github.com/gobuffalo/flect v1.0.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/golang/glog v1.2.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/mock v1.6.0 // indirect
Expand Down

0 comments on commit d739ae7

Please sign in to comment.