Skip to content

Commit

Permalink
Custom HTTP Client TLS transport for STSWebIdentity (#612)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Valdivia <[email protected]>
  • Loading branch information
Alevsk and dvaldivia authored Feb 25, 2021
1 parent 6ac95e4 commit 9c1f0c4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 13 additions & 9 deletions pkg/auth/idp/oauth2/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ type Provider struct {
// often available via site-specific packages, such as
// google.Endpoint or github.Endpoint.
// - Scopes specifies optional requested permissions.
ClientID string
oauth2Config Configuration
oidcProvider *oidc.Provider
ClientID string
oauth2Config Configuration
oidcProvider *oidc.Provider
provHTTPClient *http.Client
}

// derivedKey is the key used to compute the HMAC for signing the oauth state parameter
Expand All @@ -103,8 +104,9 @@ var derivedKey = pbkdf2.Key([]byte(getPassphraseForIdpHmac()), []byte(getSaltFor
// NewOauth2ProviderClient instantiates a new oauth2 client using the configured credentials
// it returns a *Provider object that contains the necessary configuration to initiate an
// oauth2 authentication flow
func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, error) {
provider, err := oidc.NewProvider(ctx, GetIdpURL())
func NewOauth2ProviderClient(ctx context.Context, scopes []string, httpClient *http.Client) (*Provider, error) {
customCtx := oidc.ClientContext(ctx, httpClient)
provider, err := oidc.NewProvider(customCtx, GetIdpURL())
if err != nil {
return nil, err
}
Expand All @@ -122,6 +124,7 @@ func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, e
}
client.oidcProvider = provider
client.ClientID = GetIdpClientID()
client.provHTTPClient = httpClient

return client, nil
}
Expand Down Expand Up @@ -172,10 +175,11 @@ func (client *Provider) VerifyIdentity(ctx context.Context, code, state string)
}, nil
}
stsEndpoint := GetSTSEndpoint()
sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry)
if err != nil {
return nil, err
}
sts := credentials.New(&credentials.STSWebIdentity{
Client: client.provHTTPClient,
STSEndpoint: stsEndpoint,
GetWebIDTokenExpiry: getWebTokenExpiry,
})
return sts, nil
}

Expand Down
4 changes: 2 additions & 2 deletions restapi/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func getLoginDetailsResponse() (*models.LoginDetails, *models.Error) {
if oauth2.IsIdpEnabled() {
loginStrategy = models.LoginDetailsLoginStrategyRedirect
// initialize new oauth2 client
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil)
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient())
if err != nil {
return nil, prepareError(err)
}
Expand Down Expand Up @@ -235,7 +235,7 @@ func getLoginOauth2AuthResponse(lr *models.LoginOauth2AuthRequest) (*models.Logi
return loginResponse, nil
} else if oauth2.IsIdpEnabled() {
// initialize new oauth2 client
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil)
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient())
if err != nil {
return nil, prepareError(err)
}
Expand Down

0 comments on commit 9c1f0c4

Please sign in to comment.