From 6baaf005ce6cd58f8915167bc87f0511e8a9bc58 Mon Sep 17 00:00:00 2001 From: Vui Lam Date: Mon, 7 Oct 2024 23:43:18 -0700 Subject: [PATCH] Simplify auth with api token against uaa Refactored login handler so that we can leverage it to perform the token refresh using provided API token. Also: - ensures that the API token based login updates CLI Context with the refresh token obtained from UAA. - do not interactively login on expired refresh token for api-token type tokens. Signed-off-by: Vui Lam --- pkg/auth/common/login_handler.go | 24 +++++-- pkg/auth/common/login_handler_test.go | 71 +++++++++++++++++---- pkg/auth/uaa/uaa.go | 59 +---------------- pkg/auth/uaa/uaa_test.go | 92 --------------------------- pkg/command/context.go | 23 ++++++- 5 files changed, 102 insertions(+), 167 deletions(-) delete mode 100644 pkg/auth/uaa/uaa_test.go diff --git a/pkg/auth/common/login_handler.go b/pkg/auth/common/login_handler.go index 0294ea728..9cf1f34f9 100644 --- a/pkg/auth/common/login_handler.go +++ b/pkg/auth/common/login_handler.go @@ -60,6 +60,7 @@ type TanzuLoginHandler struct { callbackHandlerMutex sync.Mutex tlsSkipVerify bool caCertData string + suppressInteractive bool } // LoginOption is an optional configuration for Login(). @@ -133,6 +134,15 @@ func WithClientID(clientID string) LoginOption { } } +// WithSuppressInteractive specifies whether to fall back to interactive login if +// an access token cannot be obtained. +func WithSuppressInteractive(suppress bool) LoginOption { + return func(h *TanzuLoginHandler) error { + h.suppressInteractive = suppress + return nil + } +} + // WithListenerPort specifies a TCP listener port on localhost, which will be used for the redirect_uri and to handle the // authorization code callback. By default, a random high port will be chosen which requires the authorization server // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252#section-7.3: @@ -166,18 +176,22 @@ func WithListenerPortFromEnv(envVarName string) LoginOption { } func (h *TanzuLoginHandler) DoLogin() (*Token, error) { + var err error + var token *Token + if h.refreshToken != "" { - ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) - token, err := h.getTokenWithRefreshToken(ctx) - if err == nil { - return token, nil + token, err = h.getTokenWithRefreshToken() + if err == nil || h.suppressInteractive { + return token, err } } + // If refresh token fails, proceed with login flow through the browser return h.browserLogin() } -func (h *TanzuLoginHandler) getTokenWithRefreshToken(ctx context.Context) (*Token, error) { +func (h *TanzuLoginHandler) getTokenWithRefreshToken() (*Token, error) { + ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) refreshedToken, err := h.oauthConfig.TokenSource(ctx, &oauth2.Token{RefreshToken: h.refreshToken}).Token() if err != nil { return nil, err diff --git a/pkg/auth/common/login_handler_test.go b/pkg/auth/common/login_handler_test.go index 6c96df264..5ad6d0ecb 100644 --- a/pkg/auth/common/login_handler_test.go +++ b/pkg/auth/common/login_handler_test.go @@ -27,6 +27,8 @@ const ( ) func TestHandleTokenRefresh(t *testing.T) { + assert := assert.New(t) + // Mock HTTP server for token refresh server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -44,20 +46,65 @@ func TestHandleTokenRefresh(t *testing.T) { refreshToken: "fake-refresh-token", } - token, err := lh.getTokenWithRefreshToken(context.TODO()) - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if token == nil { - t.Error("Expected a non-nil token, got nil") + token, err := lh.getTokenWithRefreshToken() + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) +} + +// test that login with refresh token completes without triggering browser +// login regardless of whether refresh succeeded or not +func TestLoginWithAPIToken(t *testing.T) { + assert := assert.New(t) + + // Mock HTTP server for token refresh + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), "refresh_token=valid-api-token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token": "fake-access-token", "refresh_token": "fake-refresh-token", "expires_in": 3600, "id_token": "fake-id-token"}`)) + return + } + http.Error(w, "refresh_error", http.StatusBadRequest) + })) + defer server.Close() + + lh := &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "valid-api-token", + suppressInteractive: true, } - if token != nil { - assert.Equal(t, token.AccessToken, "fake-access-token") - assert.Equal(t, token.RefreshToken, "fake-refresh-token") - assert.Equal(t, token.TokenType, "id-token") - assert.Equal(t, token.IDToken, "fake-id-token") - assert.Equal(t, token.ExpiresIn, int64(3599)) + token, err := lh.DoLogin() + + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) + + lh = &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "bad-refresh-token", + suppressInteractive: true, } + token, err = lh.DoLogin() + assert.NotNil(err) + assert.Nil(token) } func TestGetAuthCodeURL_validResponse(t *testing.T) { diff --git a/pkg/auth/uaa/uaa.go b/pkg/auth/uaa/uaa.go index 500c14fd5..eba4b91e2 100644 --- a/pkg/auth/uaa/uaa.go +++ b/pkg/auth/uaa/uaa.go @@ -5,66 +5,10 @@ package uaa import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/url" - - "github.com/pkg/errors" - "github.com/vmware-tanzu/tanzu-cli/pkg/auth/common" "github.com/vmware-tanzu/tanzu-cli/pkg/constants" - "github.com/vmware-tanzu/tanzu-cli/pkg/interfaces" ) -var ( - httpRestClient interfaces.HTTPClient -) - -// GetAccessTokenFromAPIToken fetches access token using the API-token. -func GetAccessTokenFromAPIToken(apiToken, uaaEndpoint, endpointCACertPath string, skipTLSVerify bool) (*common.Token, error) { - tokenURL := getIssuerEndpoints(uaaEndpoint).TokenURL - data := url.Values{} - data.Set("refresh_token", apiToken) - data.Set("client_id", GetAlternateClientID()) - data.Set("grant_type", "refresh_token") - - req, _ := http.NewRequestWithContext(context.Background(), "POST", tokenURL, bytes.NewBufferString(data.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - if httpRestClient == nil { - tlsConfig := common.GetTLSConfig(uaaEndpoint, endpointCACertPath, skipTLSVerify) - if tlsConfig == nil { - return nil, errors.New("unable to set up tls config") - } - - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = tlsConfig - httpRestClient = &http.Client{Transport: tr} - } - - resp, err := httpRestClient.Do(req) - if err != nil { - return nil, errors.WithMessage(err, "Failed to obtain access token. Please provide valid API token") - } - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, errors.Errorf("Failed to obtain access token. Please provide valid API token -- %s", string(body)) - } - - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - token := common.Token{} - - if err = json.Unmarshal(body, &token); err != nil { - return nil, errors.Wrap(err, "could not unmarshal auth token") - } - - return &token, nil -} - // GetTokens fetches the UAA access token func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, error) { clientID := tanzuCLIClientID @@ -72,6 +16,9 @@ func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, e clientID = GetAlternateClientID() } loginOptions := []common.LoginOption{common.WithRefreshToken(refreshOrAPIToken), common.WithListenerPortFromEnv(constants.TanzuCLIOAuthLocalListenerPort), common.WithClientID(clientID)} + if tokenType == common.APITokenType { + loginOptions = append(loginOptions, common.WithSuppressInteractive(true)) + } token, err := TanzuLogin(issuer, loginOptions...) if err != nil { diff --git a/pkg/auth/uaa/uaa_test.go b/pkg/auth/uaa/uaa_test.go deleted file mode 100644 index 1a37edd56..000000000 --- a/pkg/auth/uaa/uaa_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2023 VMware, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package uaa - -import ( - "bytes" - "fmt" - "io" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/vmware-tanzu/tanzu-cli/pkg/fakes" -) - -const ( - fakeIssuerURL = "https://auth0.com/" - fakeAPIToken = "fake_api_token" - fakeCACrtPath = "/fake/ca.crt" - fakeSkipVerify = false -) - -func TestGetAccessTokenFromAPIToken(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - if err != nil { - fmt.Println(err) - fmt.Println("Error...................................") - } - assert.Nil(err) - assert.Equal("LetMeIn", token.AccessToken) - - req := fakeHTTPClient.DoArgsForCall(0) - bodyBytes, _ := io.ReadAll(req.Body) - body := string(bodyBytes) - - assert.Contains(body, "refresh_token="+fakeAPIToken) - assert.Contains(body, "client_id="+GetAlternateClientID()) - assert.Contains(body, "grant_type=refresh_token") -} - -func TestGetAccessTokenFromAPIToken_FailStatus(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(``))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 403, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "Failed to obtain access token. Please provide valid API token") - assert.Nil(token) -} - -func TestGetAccessTokenFromAPIToken_InvalidResponse(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`[{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}]`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "could not unmarshal") - assert.Nil(token) -} diff --git a/pkg/command/context.go b/pkg/command/context.go index 9cf39882b..6bae0224f 100644 --- a/pkg/command/context.go +++ b/pkg/command/context.go @@ -733,7 +733,26 @@ func getSelfManagedOrg(c *configtypes.Context) (string, string) { } func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiTokenValue string) (claims *commonauth.Claims, err error) { - token, err := uaa.GetAccessTokenFromAPIToken(apiTokenValue, uaaEndpoint, endpointCACertPath, skipTLSVerify) + loginOptions := []commonauth.LoginOption{ + commonauth.WithSuppressInteractive(true), // fail instead of falling back to interactive login + commonauth.WithRefreshToken(apiTokenValue), + commonauth.WithClientID(uaa.GetAlternateClientID()), + } + + var endpointCACertData string + if endpointCACertPath != "" { + fileBytes, err := os.ReadFile(endpointCACertPath) + if err != nil { + return nil, errors.Wrapf(err, "error reading certificate file %s", endpointCACertPath) + } + endpointCACertData = base64.StdEncoding.EncodeToString(fileBytes) + } + if skipTLSVerify || endpointCACertData != "" { + loginOptions = append(loginOptions, commonauth.WithCertInfo(skipTLSVerify, endpointCACertData)) + } + + // Invoke TanzuLogin to obtain access token via API token + token, err := uaa.TanzuLogin(uaaEndpoint, loginOptions...) if err != nil { return nil, errors.Wrap(err, "failed to get token from UAA") } @@ -748,7 +767,7 @@ func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiT a.Permissions = claims.Permissions a.AccessToken = token.AccessToken a.IDToken = token.IDToken - a.RefreshToken = apiTokenValue + a.RefreshToken = token.RefreshToken a.Type = commonauth.APITokenType expiresAt := time.Now().Local().Add(time.Second * time.Duration(token.ExpiresIn)) a.Expiration = expiresAt