Add max duration and metrics prefix to token cache #877

Open · wants to merge 1 commit into base: main
9 changes: 9 additions & 0 deletions cache/store.go
@@ -47,6 +47,7 @@ type storeOptions struct {
interval time.Duration
registerer prometheus.Registerer
metricsPrefix string
maxDuration time.Duration
involvedObject *InvolvedObject
debugKey string
debugValueFunc func(any) any
@@ -88,6 +89,14 @@ func WithMetricsPrefix(prefix string) Options {
}
}

// WithMaxDuration sets the maximum duration for the cache items.
func WithMaxDuration(duration time.Duration) Options {
return func(o *storeOptions) error {
o.maxDuration = duration
return nil
}
}

// WithInvolvedObject sets the involved object for the cache metrics.
func WithInvolvedObject(kind, name, namespace string) Options {
return func(o *storeOptions) error {
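For reference, a minimal usage sketch of the new option, assuming the package is imported as `cache` (the module path and the `staticToken` type below are placeholders for illustration, not part of this change):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"example.com/your/module/cache" // placeholder import path for the cache package changed in this PR
)

// staticToken is an illustrative Token implementation; the interface only
// requires reporting the token's lifetime via GetDuration.
type staticToken struct{ lifetime time.Duration }

func (t *staticToken) GetDuration() time.Duration { return t.lifetime }

func main() {
	// Cache up to 100 tokens and never keep an entry longer than 30 minutes.
	// NewTokenCache additionally clamps any configured value above
	// TokenMaxDuration (1h) back down to TokenMaxDuration.
	tc := cache.NewTokenCache(100, cache.WithMaxDuration(30*time.Minute))

	token, retrieved, err := tc.GetOrSet(context.Background(), "my-key",
		func(ctx context.Context) (cache.Token, error) {
			// In real code this callback would fetch a fresh token from the provider.
			return &staticToken{lifetime: 2 * time.Hour}, nil
		})
	if err != nil {
		panic(err)
	}
	fmt.Printf("lifetime=%s servedFromCache=%t\n", token.GetDuration(), retrieved)
}
```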
24 changes: 22 additions & 2 deletions cache/token.go
@@ -21,6 +21,11 @@ import (
"time"
)

// TokenMaxDuration is the maximum duration that a token can have in the
// TokenCache. This is used to cap the duration of tokens to avoid storing
// tokens that are valid for too long.
const TokenMaxDuration = time.Hour

// Token is an interface that represents an access token that can be used
// to authenticate with a cloud provider. The only common method is to get the
// duration of the token, because different providers may have different ways to
@@ -45,7 +50,8 @@ type Token interface {
// lifetime, which is the same strategy used by kubelet for rotating
// ServiceAccount tokens.
type TokenCache struct {
cache *LRU[*tokenItem]
cache *LRU[*tokenItem]
maxDuration time.Duration
}

type tokenItem struct {
@@ -55,9 +61,19 @@ type tokenItem struct {
}

// NewTokenCache returns a new TokenCache with the given capacity.
// The metrics prefix is always set to "token_".
func NewTokenCache(capacity int, opts ...Options) *TokenCache {
o := storeOptions{maxDuration: TokenMaxDuration}
o.apply(opts...)

if o.maxDuration > TokenMaxDuration {
o.maxDuration = TokenMaxDuration
}

opts = append(opts, WithMetricsPrefix("token_"))
cache, _ := NewLRU[*tokenItem](capacity, opts...)
return &TokenCache{cache: cache}

return &TokenCache{cache, o.maxDuration}
}

// GetOrSet returns the token for the given key if present and not expired, or
@@ -112,6 +128,10 @@ func (c *TokenCache) newItem(token Token) *tokenItem {
// Ref: https://github.com/kubernetes/kubernetes/blob/4032177faf21ae2f99a2012634167def2376b370/pkg/kubelet/token/token_manager.go#L172-L174
d := (token.GetDuration() * 8) / 10

if m := c.maxDuration; d > m {
d = m
}

mono := time.Now().Add(d)
unix := time.Unix(mono.Unix(), 0)

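Putting the constant and the `newItem` change together: an entry's effective cache duration is 80% of the token's reported lifetime, clamped to the configured maximum. A standalone sketch of that arithmetic (not the package API, just the same calculation reproduced for illustration):

```go
package main

import (
	"fmt"
	"time"
)

// effectiveCacheDuration mirrors the clamping done in newItem:
// cache for 80% of the token's lifetime, but never longer than maxDuration.
func effectiveCacheDuration(tokenLifetime, maxDuration time.Duration) time.Duration {
	d := (tokenLifetime * 8) / 10
	if d > maxDuration {
		d = maxDuration
	}
	return d
}

func main() {
	// A 12h provider token with the default 1h cap is cached for 1h.
	fmt.Println(effectiveCacheDuration(12*time.Hour, time.Hour)) // 1h0m0s
	// A 30m token stays under the cap and is cached for 80% of its lifetime.
	fmt.Println(effectiveCacheDuration(30*time.Minute, time.Hour)) // 24m0s
}
```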
101 changes: 95 additions & 6 deletions cache/token_test.go
@@ -18,6 +18,7 @@ package cache_test

import (
"context"
"fmt"
"testing"
"time"

@@ -35,6 +36,8 @@ func (t *testToken) GetDuration() time.Duration {
}

func TestTokenCache_Lifecycle(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()
@@ -48,19 +51,105 @@ func TestTokenCache_Lifecycle(t *testing.T) {
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)
token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 2 * time.Second}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 100 * time.Second}, nil
return &testToken{duration: time.Hour}, nil
})
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())
}

time.Sleep(2 * time.Second)
func TestTokenCache_80PercentLifetime(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1)

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 5 * time.Second}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeTrue())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_MaxDuration(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1, cache.WithMaxDuration(time.Second))

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 10 * time.Millisecond}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 10 * time.Millisecond}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_GetOrSet_Error(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1)

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return nil, fmt.Errorf("failed")
})

g.Expect(err).To(HaveOccurred())
g.Expect(err).To(MatchError("failed"))
g.Expect(token).To(BeNil())
g.Expect(retrieved).To(BeFalse())
}