Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream revert revert auth token fix #5407

Merged
merged 15 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flytectl/cmd/core/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ func generateCommandFunc(cmdEntry CommandEntry) func(cmd *cobra.Command, args []
cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout())
if !cmdEntry.DisableFlyteClient {
clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).
WithTokenCache(pkce.TokenCacheKeyringProvider{
ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser),
ServiceName: pkce.KeyRingServiceName,
}).Build(ctx)
WithTokenCache(pkce.NewTokenCacheKeyringProvider(
pkce.KeyRingServiceName,
fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser),
)).Build(ctx)
if err != nil {
return err
}
Expand Down
86 changes: 80 additions & 6 deletions flytectl/pkg/pkce/token_cache_keyring.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,88 @@
package pkce

import (
"context"
"encoding/json"
"fmt"
"sync"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flytestdlib/logger"

"github.com/zalando/go-keyring"
"golang.org/x/oauth2"
)

const (
KeyRingServiceUser = "flytectl-user"
KeyRingServiceName = "flytectl"
)

// TokenCacheKeyringProvider wraps the logic to save and retrieve tokens from the OS's keyring implementation.
type TokenCacheKeyringProvider struct {
ServiceName string
ServiceUser string
mu *sync.Mutex
condLocker *cache.NoopLocker
cond *sync.Cond
}

const (
KeyRingServiceUser = "flytectl-user"
KeyRingServiceName = "flytectl"
)
func (t *TokenCacheKeyringProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) {
if existingBytes, err := json.Marshal(existing); err != nil {
return false, fmt.Errorf("unable to marshal token to save in cache due to %w", err)
} else if tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser); err != nil {
logger.Warnf(context.Background(), "unable to read token from cache but not failing the purge as the token might not have been saved at all. Error: %v", err)
return true, nil
} else if tokenJSON != string(existingBytes) {
return false, nil

Check warning on line 37 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L30-L37

Added lines #L30 - L37 were not covered by tests
}

_ = keyring.Delete(t.ServiceName, t.ServiceUser)
return true, nil

Check warning on line 41 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L40-L41

Added lines #L40 - L41 were not covered by tests
}

func (t TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error {
func (t *TokenCacheKeyringProvider) Lock() {
t.mu.Lock()

Check warning on line 45 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L44-L45

Added lines #L44 - L45 were not covered by tests
}

func (t *TokenCacheKeyringProvider) Unlock() {
t.mu.Unlock()

Check warning on line 49 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L48-L49

Added lines #L48 - L49 were not covered by tests
}

// TryLock the cache.
func (t *TokenCacheKeyringProvider) TryLock() bool {
return t.mu.TryLock()

Check warning on line 54 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L53-L54

Added lines #L53 - L54 were not covered by tests
}

// CondWait adds the current go routine to the condition waitlist and waits for another go routine to notify using CondBroadcast
// The current usage is that one who was able to acquire the lock using TryLock is the one who gets a valid token and notifies all the waitlist requesters so that they can use the new valid token.
// It also locks the Locker in the condition variable as the semantics of Wait is that it unlocks the Locker after adding
// the consumer to the waitlist and before blocking on notification.
// We use the condLocker which is noOp locker to get added to waitlist for notifications.
// The underlying notifcationList doesn't need to be guarded as it implementation is atomic and is thread safe
// Refer https://go.dev/src/runtime/sema.go
// Following is the function and its comments
// notifyListAdd adds the caller to a notify list such that it can receive
// notifications. The caller must eventually call notifyListWait to wait for
// such a notification, passing the returned ticket number.
//
// func notifyListAdd(l *notifyList) uint32 {
// // This may be called concurrently, for example, when called from
// // sync.Cond.Wait while holding a RWMutex in read mode.
// return l.wait.Add(1) - 1
// }
func (t *TokenCacheKeyringProvider) CondWait() {
t.condLocker.Lock()
t.cond.Wait()
t.condLocker.Unlock()

Check warning on line 77 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L74-L77

Added lines #L74 - L77 were not covered by tests
}

// CondBroadcast broadcasts the condition.
func (t *TokenCacheKeyringProvider) CondBroadcast() {
t.cond.Broadcast()

Check warning on line 82 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L81-L82

Added lines #L81 - L82 were not covered by tests
}

func (t *TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error {
var tokenBytes []byte
if token.AccessToken == "" {
return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry)
Expand All @@ -38,7 +101,7 @@
return nil
}

func (t TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) {
func (t *TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) {
// get saved token
tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser)
if len(tokenJSON) == 0 {
Expand All @@ -56,3 +119,14 @@

return &token, nil
}

func NewTokenCacheKeyringProvider(serviceName, serviceUser string) *TokenCacheKeyringProvider {
condLocker := &cache.NoopLocker{}
return &TokenCacheKeyringProvider{
mu: &sync.Mutex{},
condLocker: condLocker,
cond: sync.NewCond(condLocker),
ServiceName: serviceName,
ServiceUser: serviceUser,

Check warning on line 130 in flytectl/pkg/pkce/token_cache_keyring.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_keyring.go#L123-L130

Added lines #L123 - L130 were not covered by tests
}
}
52 changes: 45 additions & 7 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
Expand All @@ -42,11 +43,17 @@

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return err
return fmt.Errorf("failed to get token source. Error: %w", err)

Check warning on line 46 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L46

Added line #L46 was not covered by tests
}

_, err = tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)

Check warning on line 51 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L51

Added line #L51 was not covered by tests
}

wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)

return nil
}

Expand Down Expand Up @@ -134,19 +141,50 @@
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {
err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}
}

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err)

if st, ok := status.FromError(err); ok {
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
err = func() error {
if !tokenCache.TryLock() {
tokenCache.CondWait()
return nil

Check warning on line 163 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L162-L163

Added lines #L162 - L163 were not covered by tests
}

defer tokenCache.Unlock()
_, err := tokenCache.PurgeIfEquals(t)
if err != nil && !errors.Is(err, cache.ErrNotFound) {
logger.Errorf(ctx, "Failed to purge cache. Error [%v]", err)
return fmt.Errorf("failed to purge cache. Error: %w", err)

Check warning on line 170 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L169-L170

Added lines #L169 - L170 were not covered by tests
}

logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if newErr != nil {
errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr)
logger.Errorf(ctx, errString)
return fmt.Errorf(errString)

Check warning on line 178 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L176-L178

Added lines #L176 - L178 were not covered by tests
}

tokenCache.CondBroadcast()
return nil
}()

if err != nil {
return err

Check warning on line 186 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L186

Added line #L186 was not covered by tests
}

return invoker(ctx, method, req, reply, cc, opts...)
}
}
Expand Down
Loading
Loading