Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor api token based login to UAA, fix refresh token logic #824

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions pkg/auth/common/login_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type TanzuLoginHandler struct {
callbackHandlerMutex sync.Mutex
tlsSkipVerify bool
caCertData string
suppressInteractive bool
}

// LoginOption is an optional configuration for Login().
Expand Down Expand Up @@ -133,6 +134,15 @@ func WithClientID(clientID string) LoginOption {
}
}

// WithSuppressInteractive specifies whether to fall back to interactive login if
// an access token cannot be obtained.
func WithSuppressInteractive(suppress bool) LoginOption {
return func(h *TanzuLoginHandler) error {
h.suppressInteractive = suppress
return nil
}
}

// WithListenerPort specifies a TCP listener port on localhost, which will be used for the redirect_uri and to handle the
// authorization code callback. By default, a random high port will be chosen which requires the authorization server
// to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252#section-7.3:
Expand Down Expand Up @@ -166,18 +176,28 @@ func WithListenerPortFromEnv(envVarName string) LoginOption {
}

func (h *TanzuLoginHandler) DoLogin() (*Token, error) {
var err error
var token *Token

if h.refreshToken != "" {
token, err := h.getTokenWithRefreshToken()
if err == nil {
return token, nil
token, err = h.getTokenWithRefreshToken()
if err == nil || h.suppressInteractive {
// non interactive login mode should update the cert map as well
// before returning.
if err == nil && h.suppressInteractive {
h.updateCertMap()
}
return token, err
}
}

// If refresh token fails, proceed with login flow through the browser
return h.browserLogin()
}

func (h *TanzuLoginHandler) getTokenWithRefreshToken() (*Token, error) {
refreshedToken, err := h.oauthConfig.TokenSource(context.TODO(), &oauth2.Token{RefreshToken: h.refreshToken}).Token()
ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig())
refreshedToken, err := h.oauthConfig.TokenSource(ctx, &oauth2.Token{RefreshToken: h.refreshToken}).Token()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -492,16 +512,20 @@ func GetTLSConfig(endpoint, certData string, skipVerify bool) *tls.Config {
return nil
}

func contextWithCustomTLSConfig(ctx context.Context, tlsConfig *tls.Config) context.Context {
if tlsConfig != nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig

sslcli := &http.Client{Transport: tr}
ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli)
}
return ctx
}

func (h *TanzuLoginHandler) getTokenUsingAuthCode(ctx context.Context, code string) (*oauth2.Token, error) {
if h.idpType == config.UAAIdpType {
tlsConfig := h.getTLSConfig()
if tlsConfig != nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig

sslcli := &http.Client{Transport: tr}
ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli)
}
ctx = contextWithCustomTLSConfig(ctx, h.getTLSConfig())
}

token, err := h.oauthConfig.Exchange(ctx, code, h.pkceCodePair.Verifier())
Expand Down
69 changes: 58 additions & 11 deletions pkg/auth/common/login_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const (
)

func TestHandleTokenRefresh(t *testing.T) {
assert := assert.New(t)

// Mock HTTP server for token refresh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
Expand All @@ -45,19 +47,64 @@ func TestHandleTokenRefresh(t *testing.T) {
}

token, err := lh.getTokenWithRefreshToken()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if token == nil {
t.Error("Expected a non-nil token, got nil")
assert.Nil(err)
assert.NotNil(token)
assert.Equal(token.AccessToken, "fake-access-token")
assert.Equal(token.RefreshToken, "fake-refresh-token")
assert.Equal(token.TokenType, "id-token")
assert.Equal(token.IDToken, "fake-id-token")
assert.Equal(token.ExpiresIn, int64(3599))
}

// test that login with refresh token completes without triggering browser
// login regardless of whether refresh succeeded or not
func TestLoginWithAPIToken(t *testing.T) {
assert := assert.New(t)

// Mock HTTP server for token refresh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
if strings.Contains(string(body), "refresh_token=valid-api-token") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token": "fake-access-token", "refresh_token": "fake-refresh-token", "expires_in": 3600, "id_token": "fake-id-token"}`))
return
}
http.Error(w, "refresh_error", http.StatusBadRequest)
}))
defer server.Close()

lh := &TanzuLoginHandler{
oauthConfig: &oauth2.Config{
Endpoint: oauth2.Endpoint{
TokenURL: server.URL,
},
},
refreshToken: "valid-api-token",
suppressInteractive: true,
}
if token != nil {
assert.Equal(t, token.AccessToken, "fake-access-token")
assert.Equal(t, token.RefreshToken, "fake-refresh-token")
assert.Equal(t, token.TokenType, "id-token")
assert.Equal(t, token.IDToken, "fake-id-token")
assert.Equal(t, token.ExpiresIn, int64(3599))
token, err := lh.DoLogin()

assert.Nil(err)
assert.NotNil(token)
assert.Equal(token.AccessToken, "fake-access-token")
assert.Equal(token.RefreshToken, "fake-refresh-token")
assert.Equal(token.TokenType, "id-token")
assert.Equal(token.IDToken, "fake-id-token")
assert.Equal(token.ExpiresIn, int64(3599))

lh = &TanzuLoginHandler{
oauthConfig: &oauth2.Config{
Endpoint: oauth2.Endpoint{
TokenURL: server.URL,
},
},
refreshToken: "bad-refresh-token",
suppressInteractive: true,
}
token, err = lh.DoLogin()
assert.NotNil(err)
assert.Nil(token)
}

func TestGetAuthCodeURL_validResponse(t *testing.T) {
Expand Down
59 changes: 3 additions & 56 deletions pkg/auth/uaa/uaa.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,73 +5,20 @@
package uaa

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"

"github.com/pkg/errors"

"github.com/vmware-tanzu/tanzu-cli/pkg/auth/common"
"github.com/vmware-tanzu/tanzu-cli/pkg/constants"
"github.com/vmware-tanzu/tanzu-cli/pkg/interfaces"
)

var (
httpRestClient interfaces.HTTPClient
)

// GetAccessTokenFromAPIToken fetches access token using the API-token.
func GetAccessTokenFromAPIToken(apiToken, uaaEndpoint, endpointCACertPath string, skipTLSVerify bool) (*common.Token, error) {
tokenURL := getIssuerEndpoints(uaaEndpoint).TokenURL
data := url.Values{}
data.Set("refresh_token", apiToken)
data.Set("client_id", GetAlternateClientID())
data.Set("grant_type", "refresh_token")

req, _ := http.NewRequestWithContext(context.Background(), "POST", tokenURL, bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

if httpRestClient == nil {
tlsConfig := common.GetTLSConfig(uaaEndpoint, endpointCACertPath, skipTLSVerify)
if tlsConfig == nil {
return nil, errors.New("unable to set up tls config")
}

tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig
httpRestClient = &http.Client{Transport: tr}
}

resp, err := httpRestClient.Do(req)
if err != nil {
return nil, errors.WithMessage(err, "Failed to obtain access token. Please provide valid API token")
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, errors.Errorf("Failed to obtain access token. Please provide valid API token -- %s", string(body))
}

defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
token := common.Token{}

if err = json.Unmarshal(body, &token); err != nil {
return nil, errors.Wrap(err, "could not unmarshal auth token")
}

return &token, nil
}

// GetTokens fetches the UAA access token
func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, error) {
clientID := tanzuCLIClientID
if tokenType == common.APITokenType {
clientID = GetAlternateClientID()
}
loginOptions := []common.LoginOption{common.WithRefreshToken(refreshOrAPIToken), common.WithListenerPortFromEnv(constants.TanzuCLIOAuthLocalListenerPort), common.WithClientID(clientID)}
if tokenType == common.APITokenType {
loginOptions = append(loginOptions, common.WithSuppressInteractive(true))
}

token, err := TanzuLogin(issuer, loginOptions...)
if err != nil {
Expand Down
92 changes: 0 additions & 92 deletions pkg/auth/uaa/uaa_test.go

This file was deleted.

23 changes: 21 additions & 2 deletions pkg/command/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,26 @@ func getSelfManagedOrg(c *configtypes.Context) (string, string) {
}

func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiTokenValue string) (claims *commonauth.Claims, err error) {
token, err := uaa.GetAccessTokenFromAPIToken(apiTokenValue, uaaEndpoint, endpointCACertPath, skipTLSVerify)
loginOptions := []commonauth.LoginOption{
commonauth.WithSuppressInteractive(true), // fail instead of falling back to interactive login
commonauth.WithRefreshToken(apiTokenValue),
commonauth.WithClientID(uaa.GetAlternateClientID()),
}

var endpointCACertData string
if endpointCACertPath != "" {
fileBytes, err := os.ReadFile(endpointCACertPath)
if err != nil {
return nil, errors.Wrapf(err, "error reading certificate file %s", endpointCACertPath)
}
endpointCACertData = base64.StdEncoding.EncodeToString(fileBytes)
}
if skipTLSVerify || endpointCACertData != "" {
loginOptions = append(loginOptions, commonauth.WithCertInfo(skipTLSVerify, endpointCACertData))
}

// Invoke TanzuLogin to obtain access token via API token
token, err := uaa.TanzuLogin(uaaEndpoint, loginOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to get token from UAA")
}
Expand All @@ -748,7 +767,7 @@ func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiT
a.Permissions = claims.Permissions
a.AccessToken = token.AccessToken
a.IDToken = token.IDToken
a.RefreshToken = apiTokenValue
a.RefreshToken = token.RefreshToken
a.Type = commonauth.APITokenType
expiresAt := time.Now().Local().Add(time.Second * time.Duration(token.ExpiresIn))
a.Expiration = expiresAt
Expand Down
Loading