diff --git a/flytectl/cmd/core/cmd.go b/flytectl/cmd/core/cmd.go index 989f4b7ebb..b2bd77c317 100644 --- a/flytectl/cmd/core/cmd.go +++ b/flytectl/cmd/core/cmd.go @@ -73,10 +73,10 @@ func generateCommandFunc(cmdEntry CommandEntry) func(cmd *cobra.Command, args [] cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout()) if !cmdEntry.DisableFlyteClient { clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)). - WithTokenCache(pkce.TokenCacheKeyringProvider{ - ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser), - ServiceName: pkce.KeyRingServiceName, - }).Build(ctx) + WithTokenCache(pkce.NewTokenCacheKeyringProvider( + pkce.KeyRingServiceName, + fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser), + )).Build(ctx) if err != nil { return err } diff --git a/flytectl/pkg/pkce/token_cache_keyring.go b/flytectl/pkg/pkce/token_cache_keyring.go index 119fea5033..afcfa74db5 100644 --- a/flytectl/pkg/pkce/token_cache_keyring.go +++ b/flytectl/pkg/pkce/token_cache_keyring.go @@ -1,25 +1,88 @@ package pkce import ( + "context" "encoding/json" "fmt" + "sync" + + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" + "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/zalando/go-keyring" "golang.org/x/oauth2" ) +const ( + KeyRingServiceUser = "flytectl-user" + KeyRingServiceName = "flytectl" +) + // TokenCacheKeyringProvider wraps the logic to save and retrieve tokens from the OS's keyring implementation. type TokenCacheKeyringProvider struct { ServiceName string ServiceUser string + mu *sync.Mutex + condLocker *cache.NoopLocker + cond *sync.Cond } -const ( - KeyRingServiceUser = "flytectl-user" - KeyRingServiceName = "flytectl" -) +func (t *TokenCacheKeyringProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) { + if existingBytes, err := json.Marshal(existing); err != nil { + return false, fmt.Errorf("unable to marshal token to save in cache due to %w", err) + } else if tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser); err != nil { + logger.Warnf(context.Background(), "unable to read token from cache but not failing the purge as the token might not have been saved at all. Error: %v", err) + return true, nil + } else if tokenJSON != string(existingBytes) { + return false, nil + } + + _ = keyring.Delete(t.ServiceName, t.ServiceUser) + return true, nil +} -func (t TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { +func (t *TokenCacheKeyringProvider) Lock() { + t.mu.Lock() +} + +func (t *TokenCacheKeyringProvider) Unlock() { + t.mu.Unlock() +} + +// TryLock the cache. +func (t *TokenCacheKeyringProvider) TryLock() bool { + return t.mu.TryLock() +} + +// CondWait adds the current go routine to the condition waitlist and waits for another go routine to notify using CondBroadcast +// The current usage is that one who was able to acquire the lock using TryLock is the one who gets a valid token and notifies all the waitlist requesters so that they can use the new valid token. +// It also locks the Locker in the condition variable as the semantics of Wait is that it unlocks the Locker after adding +// the consumer to the waitlist and before blocking on notification. +// We use the condLocker which is noOp locker to get added to waitlist for notifications. +// The underlying notifcationList doesn't need to be guarded as it implementation is atomic and is thread safe +// Refer https://go.dev/src/runtime/sema.go +// Following is the function and its comments +// notifyListAdd adds the caller to a notify list such that it can receive +// notifications. The caller must eventually call notifyListWait to wait for +// such a notification, passing the returned ticket number. +// +// func notifyListAdd(l *notifyList) uint32 { +// // This may be called concurrently, for example, when called from +// // sync.Cond.Wait while holding a RWMutex in read mode. +// return l.wait.Add(1) - 1 +// } +func (t *TokenCacheKeyringProvider) CondWait() { + t.condLocker.Lock() + t.cond.Wait() + t.condLocker.Unlock() +} + +// CondBroadcast broadcasts the condition. +func (t *TokenCacheKeyringProvider) CondBroadcast() { + t.cond.Broadcast() +} + +func (t *TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { var tokenBytes []byte if token.AccessToken == "" { return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry) @@ -38,7 +101,7 @@ func (t TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { return nil } -func (t TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { +func (t *TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { // get saved token tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser) if len(tokenJSON) == 0 { @@ -56,3 +119,14 @@ func (t TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { return &token, nil } + +func NewTokenCacheKeyringProvider(serviceName, serviceUser string) *TokenCacheKeyringProvider { + condLocker := &cache.NoopLocker{} + return &TokenCacheKeyringProvider{ + mu: &sync.Mutex{}, + condLocker: condLocker, + cond: sync.NewCond(condLocker), + ServiceName: serviceName, + ServiceUser: serviceUser, + } +} diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index 8a0024b319..4cebf6440f 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -20,7 +20,8 @@ const ProxyAuthorizationHeader = "proxy-authorization" // MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. -func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, + perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture) if err != nil { return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) @@ -42,11 +43,17 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) if err != nil { - return err + return fmt.Errorf("failed to get token source. Error: %w", err) + } + + _, err = tokenSource.Token() + if err != nil { + return fmt.Errorf("failed to issue token. Error: %w", err) } wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) perRPCCredentials.Store(wrappedTokenSource) + return nil } @@ -134,6 +141,15 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) + // If there is already a token in the cache (e.g. key-ring), we should use it immediately... + t, _ := tokenCache.GetToken() + if t != nil { + err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) + if err != nil { + return fmt.Errorf("failed to materialize credentials. Error: %v", err) + } + } + err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) @@ -141,12 +157,34 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut if st, ok := status.FromError(err); ok { // If the error we receive from executing the request expects if shouldAttemptToAuthenticate(st.Code()) { - logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) - newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) - if newErr != nil { - return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) + err = func() error { + if !tokenCache.TryLock() { + tokenCache.CondWait() + return nil + } + + defer tokenCache.Unlock() + _, err := tokenCache.PurgeIfEquals(t) + if err != nil && !errors.Is(err, cache.ErrNotFound) { + logger.Errorf(ctx, "Failed to purge cache. Error [%v]", err) + return fmt.Errorf("failed to purge cache. Error: %w", err) + } + + logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) + if newErr != nil { + errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr) + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + + tokenCache.CondBroadcast() + return nil + }() + + if err != nil { + return err } - return invoker(ctx, method, req, reply, cc, opts...) } } diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index ce99c99270..10c96625b7 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -2,13 +2,14 @@ package admin import ( "context" + "encoding/json" "errors" "fmt" "io" "net" "net/http" - "net/http/httptest" "net/url" + "os" "strings" "sync" "testing" @@ -31,10 +32,11 @@ import ( // authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one // initialized through mockery) and starts a local server that uses it to respond to grpc requests. type authMetadataServer struct { - s *httptest.Server t testing.TB - port int + grpcPort int + httpPort int grpcServer *grpc.Server + httpServer *http.Server netListener net.Listener impl service.AuthMetadataServiceServer lck *sync.RWMutex @@ -70,27 +72,49 @@ func (s authMetadataServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } +func (s *authMetadataServer) tokenHandler(w http.ResponseWriter, r *http.Request) { + tokenJSON := []byte(`{"access_token": "exampletoken", "token_type": "bearer"}`) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(tokenJSON) + assert.NoError(s.t, err) +} + func (s *authMetadataServer) Start(_ context.Context) error { s.lck.Lock() defer s.lck.Unlock() /***** Set up the server serving channelz service. *****/ - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.port)) + + lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.grpcPort)) if err != nil { - return fmt.Errorf("failed to listen on port [%v]: %w", s.port, err) + return fmt.Errorf("failed to listen on port [%v]: %w", s.grpcPort, err) } + s.netListener = lis grpcS := grpc.NewServer() service.RegisterAuthMetadataServiceServer(grpcS, s) go func() { + defer grpcS.Stop() _ = grpcS.Serve(lis) - //assert.NoError(s.t, err) }() - s.grpcServer = grpcS - s.netListener = lis + mux := http.NewServeMux() + // Attach the handler to the /oauth2/token path + mux.HandleFunc("/oauth2/token", s.tokenHandler) + + //nolint:gosec + s.httpServer = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", s.httpPort), + Handler: mux, + } - s.s = httptest.NewServer(s) + go func() { + defer s.httpServer.Close() + err := s.httpServer.ListenAndServe() + if err != nil { + panic(err) + } + }() return nil } @@ -98,25 +122,30 @@ func (s *authMetadataServer) Start(_ context.Context) error { func (s *authMetadataServer) Close() { s.lck.RLock() defer s.lck.RUnlock() - s.grpcServer.Stop() - s.s.Close() } -func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServiceServer) *authMetadataServer { +func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl service.AuthMetadataServiceServer) *authMetadataServer { return &authMetadataServer{ - port: port, - t: t, - impl: impl, - lck: &sync.RWMutex{}, + grpcPort: grpcPort, + httpPort: httpPort, + t: t, + impl: impl, + lck: &sync.RWMutex{}, } } func Test_newAuthInterceptor(t *testing.T) { + plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json") + var tokenData oauth2.Token + err := json.Unmarshal(plan, &tokenData) + assert.NoError(t, err) t.Run("Other Error", func(t *testing.T) { f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p) + mockTokenCache := &mocks.TokenCache{} + mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) + interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p) otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Canceled, "").Err() } @@ -129,35 +158,43 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), }, nil) + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnTryLockMatch().Return(true) + c.OnSaveTokenMatch(mock.Anything).Return(nil) + c.On("CondBroadcast").Return() + c.On("Unlock").Return() + c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Unauthenticated, "").Err() } - err = interceptor(ctx, "POST", nil, nil, nil, unauthenticated) assert.Error(t, err) assert.Truef(t, f.IsInitialized(), "PerRPCCredentialFuture should be initialized") @@ -169,24 +206,26 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil } @@ -201,33 +240,39 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), }, nil) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnTryLockMatch().Return(true) + c.OnSaveTokenMatch(mock.Anything).Return(nil) + c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Aborted, "").Err() } @@ -239,17 +284,21 @@ func Test_newAuthInterceptor(t *testing.T) { } func TestMaterializeCredentials(t *testing.T) { - port := rand.IntnRange(10000, 60000) t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) { + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -259,24 +308,29 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), Scopes: []string{"all"}, Audience: "http://localhost:30081", AuthorizationHeader: "authorization", - }, &mocks.TokenCache{}, f, p) + }, c, f, p) assert.NoError(t, err) }) t.Run("Failed to fetch client metadata", func(t *testing.T) { + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) failedPublicClientConfigLookup := errors.New("expected err") m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -286,9 +340,9 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort), Scopes: []string{"all"}, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err") }) } diff --git a/flyteidl/clients/go/admin/cache/mocks/token_cache.go b/flyteidl/clients/go/admin/cache/mocks/token_cache.go index 0af58b381f..88a1bef81c 100644 --- a/flyteidl/clients/go/admin/cache/mocks/token_cache.go +++ b/flyteidl/clients/go/admin/cache/mocks/token_cache.go @@ -12,6 +12,16 @@ type TokenCache struct { mock.Mock } +// CondBroadcast provides a mock function with given fields: +func (_m *TokenCache) CondBroadcast() { + _m.Called() +} + +// CondWait provides a mock function with given fields: +func (_m *TokenCache) CondWait() { + _m.Called() +} + type TokenCache_GetToken struct { *mock.Call } @@ -53,6 +63,50 @@ func (_m *TokenCache) GetToken() (*oauth2.Token, error) { return r0, r1 } +// Lock provides a mock function with given fields: +func (_m *TokenCache) Lock() { + _m.Called() +} + +type TokenCache_PurgeIfEquals struct { + *mock.Call +} + +func (_m TokenCache_PurgeIfEquals) Return(_a0 bool, _a1 error) *TokenCache_PurgeIfEquals { + return &TokenCache_PurgeIfEquals{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TokenCache) OnPurgeIfEquals(t *oauth2.Token) *TokenCache_PurgeIfEquals { + c_call := _m.On("PurgeIfEquals", t) + return &TokenCache_PurgeIfEquals{Call: c_call} +} + +func (_m *TokenCache) OnPurgeIfEqualsMatch(matchers ...interface{}) *TokenCache_PurgeIfEquals { + c_call := _m.On("PurgeIfEquals", matchers...) + return &TokenCache_PurgeIfEquals{Call: c_call} +} + +// PurgeIfEquals provides a mock function with given fields: t +func (_m *TokenCache) PurgeIfEquals(t *oauth2.Token) (bool, error) { + ret := _m.Called(t) + + var r0 bool + if rf, ok := ret.Get(0).(func(*oauth2.Token) bool); ok { + r0 = rf(t) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(*oauth2.Token) error); ok { + r1 = rf(t) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type TokenCache_SaveToken struct { *mock.Call } @@ -84,3 +138,40 @@ func (_m *TokenCache) SaveToken(token *oauth2.Token) error { return r0 } + +type TokenCache_TryLock struct { + *mock.Call +} + +func (_m TokenCache_TryLock) Return(_a0 bool) *TokenCache_TryLock { + return &TokenCache_TryLock{Call: _m.Call.Return(_a0)} +} + +func (_m *TokenCache) OnTryLock() *TokenCache_TryLock { + c_call := _m.On("TryLock") + return &TokenCache_TryLock{Call: c_call} +} + +func (_m *TokenCache) OnTryLockMatch(matchers ...interface{}) *TokenCache_TryLock { + c_call := _m.On("TryLock", matchers...) + return &TokenCache_TryLock{Call: c_call} +} + +// TryLock provides a mock function with given fields: +func (_m *TokenCache) TryLock() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Unlock provides a mock function with given fields: +func (_m *TokenCache) Unlock() { + _m.Called() +} diff --git a/flyteidl/clients/go/admin/cache/token_cache.go b/flyteidl/clients/go/admin/cache/token_cache.go index e4e2b7e17f..f2d55fc0dd 100644 --- a/flyteidl/clients/go/admin/cache/token_cache.go +++ b/flyteidl/clients/go/admin/cache/token_cache.go @@ -1,14 +1,40 @@ package cache -import "golang.org/x/oauth2" +import ( + "fmt" + + "golang.org/x/oauth2" +) //go:generate mockery -all -case=underscore +var ( + ErrNotFound = fmt.Errorf("secret not found in keyring") +) + // TokenCache defines the interface needed to cache and retrieve oauth tokens. type TokenCache interface { // SaveToken saves the token securely to cache. SaveToken(token *oauth2.Token) error - // Retrieves the token from the cache. + // GetToken retrieves the token from the cache. GetToken() (*oauth2.Token, error) + + // PurgeIfEquals purges the token from the cache. + PurgeIfEquals(t *oauth2.Token) (bool, error) + + // Lock the cache. + Lock() + + // TryLock tries to lock the cache. + TryLock() bool + + // Unlock the cache. + Unlock() + + // CondWait waits for the condition to be true. + CondWait() + + // CondSignalCondBroadcast signals the condition. + CondBroadcast() } diff --git a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go index 9c6223fc06..ca832ded1f 100644 --- a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go +++ b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go @@ -2,23 +2,93 @@ package cache import ( "fmt" + "sync" + "sync/atomic" "golang.org/x/oauth2" ) type TokenCacheInMemoryProvider struct { - token *oauth2.Token + token atomic.Value + mu *sync.Mutex + condLocker *NoopLocker + cond *sync.Cond } func (t *TokenCacheInMemoryProvider) SaveToken(token *oauth2.Token) error { - t.token = token + t.token.Store(token) return nil } -func (t TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { - if t.token == nil { +func (t *TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { + tkn := t.token.Load() + if tkn == nil { return nil, fmt.Errorf("cannot find token in cache") } + return tkn.(*oauth2.Token), nil +} + +func (t *TokenCacheInMemoryProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) { + // Add an empty token since we can't mark it nil using Compare and swap + return t.token.CompareAndSwap(existing, &oauth2.Token{}), nil +} + +func (t *TokenCacheInMemoryProvider) Lock() { + t.mu.Lock() +} + +func (t *TokenCacheInMemoryProvider) TryLock() bool { + return t.mu.TryLock() +} + +func (t *TokenCacheInMemoryProvider) Unlock() { + t.mu.Unlock() +} + +// CondWait adds the current go routine to the condition waitlist and waits for another go routine to notify using CondBroadcast +// The current usage is that one who was able to acquire the lock using TryLock is the one who gets a valid token and notifies all the waitlist requesters so that they can use the new valid token. +// It also locks the Locker in the condition variable as the semantics of Wait is that it unlocks the Locker after adding +// the consumer to the waitlist and before blocking on notification. +// We use the condLocker which is noOp locker to get added to waitlist for notifications. +// The underlying notifcationList doesn't need to be guarded as it implementation is atomic and is thread safe +// Refer https://go.dev/src/runtime/sema.go +// Following is the function and its comments +// notifyListAdd adds the caller to a notify list such that it can receive +// notifications. The caller must eventually call notifyListWait to wait for +// such a notification, passing the returned ticket number. +// +// func notifyListAdd(l *notifyList) uint32 { +// // This may be called concurrently, for example, when called from +// // sync.Cond.Wait while holding a RWMutex in read mode. +// return l.wait.Add(1) - 1 +// } +func (t *TokenCacheInMemoryProvider) CondWait() { + t.condLocker.Lock() + t.cond.Wait() + t.condLocker.Unlock() +} + +// NoopLocker has empty implementation of Locker interface +type NoopLocker struct { +} + +func (*NoopLocker) Lock() { + +} +func (*NoopLocker) Unlock() { +} - return t.token, nil +// CondBroadcast signals the condition. +func (t *TokenCacheInMemoryProvider) CondBroadcast() { + t.cond.Broadcast() +} + +func NewTokenCacheInMemoryProvider() *TokenCacheInMemoryProvider { + condLocker := &NoopLocker{} + return &TokenCacheInMemoryProvider{ + mu: &sync.Mutex{}, + token: atomic.Value{}, + condLocker: condLocker, + cond: sync.NewCond(condLocker), + } } diff --git a/flyteidl/clients/go/admin/client_builder.go b/flyteidl/clients/go/admin/client_builder.go index 25b263ecf1..0d1341bf7b 100644 --- a/flyteidl/clients/go/admin/client_builder.go +++ b/flyteidl/clients/go/admin/client_builder.go @@ -40,7 +40,7 @@ func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetB // Build the clientset using the current state of the ClientsetBuilder func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) { if cb.tokenCache == nil { - cb.tokenCache = &cache.TokenCacheInMemoryProvider{} + cb.tokenCache = cache.NewTokenCacheInMemoryProvider() } if cb.config == nil { diff --git a/flyteidl/clients/go/admin/client_builder_test.go b/flyteidl/clients/go/admin/client_builder_test.go index c871bcb326..89bcc38550 100644 --- a/flyteidl/clients/go/admin/client_builder_test.go +++ b/flyteidl/clients/go/admin/client_builder_test.go @@ -17,9 +17,9 @@ func TestClientsetBuilder_Build(t *testing.T) { cb := NewClientsetBuilder().WithConfig(&Config{ UseInsecureConnection: true, Endpoint: config.URL{URL: *u}, - }).WithTokenCache(&cache.TokenCacheInMemoryProvider{}) + }).WithTokenCache(cache.NewTokenCacheInMemoryProvider()) ctx := context.Background() _, err := cb.Build(ctx) assert.NoError(t, err) - assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&cache.TokenCacheInMemoryProvider{})) + assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(cache.NewTokenCacheInMemoryProvider())) } diff --git a/flyteidl/clients/go/admin/client_test.go b/flyteidl/clients/go/admin/client_test.go index eb19b76f47..042a826692 100644 --- a/flyteidl/clients/go/admin/client_test.go +++ b/flyteidl/clients/go/admin/client_test.go @@ -255,6 +255,8 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) { mockAuthClient := new(mocks.AuthMetadataServiceClient) mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) + mockTokenCache.On("Lock").Return() + mockTokenCache.On("Unlock").Return() mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil) mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) tokenSourceProvider, err := NewTokenSourceProvider(ctx, adminServiceConfig, mockTokenCache, mockAuthClient) @@ -288,7 +290,7 @@ func Test_getPkceAuthTokenSource(t *testing.T) { assert.NoError(t, err) // populate the cache - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() assert.NoError(t, tokenCache.SaveToken(&tokenData)) baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{ diff --git a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go index 5c1dc5f2bd..9f20fb3ef5 100644 --- a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go @@ -23,7 +23,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ @@ -97,7 +97,7 @@ func TestFetchFromAuthFlow(t *testing.T) { })) defer fakeServer.Close() - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go index dc1c80f63a..ca1973ea66 100644 --- a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go +++ b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go @@ -16,7 +16,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index d4f4a31a5a..83df542082 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -188,7 +188,7 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s } secret = strings.TrimSpace(secret) if tokenCache == nil { - tokenCache = &cache.TokenCacheInMemoryProvider{} + tokenCache = cache.NewTokenCacheInMemoryProvider() } return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ @@ -227,14 +227,14 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { token, err := s.new.Token() if err != nil { - logger.Warnf(s.ctx, "failed to get token: %w", err) + logger.Warnf(s.ctx, "failed to get token: %v", err) return nil, fmt.Errorf("failed to get token: %w", err) } logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) err = s.tokenCache.SaveToken(token) if err != nil { - logger.Warnf(s.ctx, "failed to cache token: %w", err) + logger.Warnf(s.ctx, "failed to cache token: %v", err) } return token, nil diff --git a/flyteidl/clients/go/admin/token_source_provider_test.go b/flyteidl/clients/go/admin/token_source_provider_test.go index 63fc1aa56e..43d0fdd928 100644 --- a/flyteidl/clients/go/admin/token_source_provider_test.go +++ b/flyteidl/clients/go/admin/token_source_provider_test.go @@ -127,7 +127,9 @@ func TestCustomTokenSource_Token(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { tokenCache := &tokenCacheMocks.TokenCache{} - tokenCache.OnGetToken().Return(test.token, nil).Once() + tokenCache.OnGetToken().Return(test.token, nil).Maybe() + tokenCache.On("Lock").Return().Maybe() + tokenCache.On("Unlock").Return().Maybe() provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") assert.NoError(t, err) source, err := provider.GetTokenSource(ctx) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go index c4891b13ae..4fd3fa476c 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go @@ -3,7 +3,6 @@ package tokenorchestrator import ( "context" "fmt" - "time" "golang.org/x/oauth2" @@ -53,16 +52,21 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, err } - if !token.Valid() { - return nil, fmt.Errorf("token from cache is invalid") + if token.Valid() { + return token, nil } - // If token doesn't need to be refreshed, return it. - if time.Now().Before(token.Expiry.Add(-tokenRefreshGracePeriod.Duration)) { - logger.Infof(ctx, "found the token in the cache") + t.TokenCache.Lock() + defer t.TokenCache.Unlock() + + token, err = t.TokenCache.GetToken() + if err != nil { + return nil, err + } + + if token.Valid() { return token, nil } - token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) token, err = t.RefreshToken(ctx, token) if err != nil { @@ -73,6 +77,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, fmt.Errorf("refreshed token is invalid") } + token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) + err = t.TokenCache.SaveToken(token) if err != nil { return nil, fmt.Errorf("failed to save token in the token cache. Error: %w", err) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go index ed4afa0ff0..0a1a9f4985 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go @@ -26,7 +26,7 @@ func TestRefreshTheToken(t *testing.T) { ClientID: "dummyClient", }, } - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator := BaseTokenOrchestrator{ ClientConfig: clientConf, TokenCache: tokenCacheProvider, @@ -58,7 +58,7 @@ func TestFetchFromCache(t *testing.T) { mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) t.Run("no token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) @@ -69,7 +69,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json") @@ -86,7 +86,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("expired token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json")