Skip to content

Commit

Permalink
Use proper HTTP client for fetching credentials (#2041)
Browse files Browse the repository at this point in the history
* Use proper HTTP client for fetching credentials
* Allow custom `http.Client` in credential providers.
  • Loading branch information
ramondeklein authored Dec 27, 2024
1 parent 8dc4193 commit 4a691e1
Show file tree
Hide file tree
Showing 24 changed files with 185 additions and 130 deletions.
2 changes: 1 addition & 1 deletion api-presigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (c *Client) PresignedPostPolicy(ctx context.Context, p *PostPolicy) (u *url
}

// Get credentials from the configured credentials provider.
credValues, err := c.credsProvider.Get()
credValues, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, nil, err
}
Expand Down
19 changes: 15 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,9 @@ func (c *Client) executeMethod(ctx context.Context, method string, metadata requ
return nil, errors.New(c.endpointURL.String() + " is offline.")
}

var retryable bool // Indicates if request can be retried.
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
var reqRetry = c.maxRetries // Indicates how many times we can retry the request
var retryable bool // Indicates if request can be retried.
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
reqRetry := c.maxRetries // Indicates how many times we can retry the request

if metadata.contentBody != nil {
// Check if body is seekable then it is retryable.
Expand Down Expand Up @@ -808,7 +808,7 @@ func (c *Client) newRequest(ctx context.Context, method string, metadata request
}

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1018,3 +1018,14 @@ func (c *Client) isVirtualHostStyleRequest(url url.URL, bucketName string) bool
// path style requests
return s3utils.IsVirtualHostSupported(url, bucketName)
}

// CredContext returns the context for fetching credentials
func (c *Client) CredContext() *credentials.CredContext {
httpClient := c.httpClient
if httpClient == nil {
httpClient = http.DefaultClient
}
return &credentials.CredContext{
Client: httpClient,
}
}
2 changes: 1 addition & 1 deletion bucket-cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (c *Client) getBucketLocationRequest(ctx context.Context, bucketName string
c.setUserAgent(req)

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion bucket-cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestGetBucketLocationRequest(t *testing.T) {
c.setUserAgent(req)

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ type AssumeRoleResult struct {
type STSAssumeRole struct {
Expiry

// Required http Client to use when connecting to MinIO STS service.
// Optional http Client to use when connecting to MinIO STS service
// (overrides default client in CredContext)
Client *http.Client

// STS endpoint to fetch STS credentials.
Expand Down Expand Up @@ -115,9 +116,6 @@ func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentia
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
}
return New(&STSAssumeRole{
Client: &http.Client{
Transport: http.DefaultTransport,
},
STSEndpoint: stsEndpoint,
Options: opts,
}), nil
Expand Down Expand Up @@ -224,8 +222,12 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve() (Value, error) {
a, err := getAssumeRoleCredentials(m.Client, m.STSEndpoint, m.Options)
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/credentials/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func NewChainCredentials(providers []Provider) *Credentials {
//
// If a provider is found with credentials, it will be cached and any calls
// to IsExpired() will return the expired state of the cached provider.
func (c *Chain) Retrieve() (Value, error) {
func (c *Chain) Retrieve(cc *CredContext) (Value, error) {
for _, p := range c.Providers {
creds, _ := p.Retrieve()
creds, _ := p.Retrieve(cc)
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
Expand Down
10 changes: 5 additions & 5 deletions pkg/credentials/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type testCredProvider struct {
err error
}

func (s *testCredProvider) Retrieve() (Value, error) {
func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand Down Expand Up @@ -59,7 +59,7 @@ func TestChainGet(t *testing.T) {
},
}

creds, err := p.Retrieve()
creds, err := p.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestChainIsExpired(t *testing.T) {
t.Fatal("Expected expired to be true before any Retrieve")
}

_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -112,7 +112,7 @@ func TestChainWithNoProvider(t *testing.T) {
if !p.IsExpired() {
t.Fatal("Expected to be expired with no providers")
}
_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found []" {
t.Error(err)
Expand All @@ -136,7 +136,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
t.Fatal("Expected to be expired with no providers")
}

_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found [FirstError SecondError]" {
t.Error(err)
Expand Down
34 changes: 32 additions & 2 deletions pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package credentials

import (
"net/http"
"sync"
"time"
)
Expand All @@ -30,6 +31,10 @@ const (
defaultExpiryWindow = 0.8
)

// defaultCredContext is used when the credential context doesn't
// actually matter or the default context is suitable.
var defaultCredContext = &CredContext{Client: http.DefaultClient}

// A Value is the S3 credentials value for individual credential fields.
type Value struct {
// S3 Access key ID
Expand All @@ -54,13 +59,21 @@ type Value struct {
type Provider interface {
// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
Retrieve(cc *CredContext) (Value, error)

// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
IsExpired() bool
}

// CredContext is passed to the Retrieve function of a provider to provide
// some additional context to retrieve credentials.
type CredContext struct {
// Client specifies the HTTP client that should be used if an HTTP
// request is to be made to fetch the credentials.
Client *http.Client
}

// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//
Expand Down Expand Up @@ -146,7 +159,24 @@ func New(provider Provider) *Credentials {
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
//
// Deprecated: Get() exists for historical compatibility and should not be
// used. To get new credentials use the Credentials.GetWithContext function
// to ensure the proper context (i.e. HTTP client) will be used.
func (c *Credentials) Get() (Value, error) {
return c.GetWithContext(defaultCredContext)
}

// GetWithContext returns the credentials value, or error if the
// credentials Value failed to be retrieved.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
if c == nil {
return Value{}, nil
}
Expand All @@ -155,7 +185,7 @@ func (c *Credentials) Get() (Value, error) {
defer c.Unlock()

if c.isExpired() {
creds, err := c.provider.Retrieve()
creds, err := c.provider.Retrieve(cc)
if err != nil {
return Value{}, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type credProvider struct {
err error
}

func (s *credProvider) Retrieve() (Value, error) {
func (s *credProvider) Retrieve(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand All @@ -47,7 +47,7 @@ func TestCredentialsGet(t *testing.T) {
expired: true,
})

creds, err := c.Get()
creds, err := c.GetWithContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -65,7 +65,7 @@ func TestCredentialsGet(t *testing.T) {
func TestCredentialsGetWithError(t *testing.T) {
c := New(&credProvider{err: errors.New("Custom error")})

_, err := c.Get()
_, err := c.GetWithContext(defaultCredContext)
if err != nil {
if err.Error() != "Custom error" {
t.Errorf("Expected \"Custom error\", got %s", err.Error())
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewEnvAWS() *Credentials {
}

// Retrieve retrieves the keys from the environment.
func (e *EnvAWS) Retrieve() (Value, error) {
func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) {
e.retrieved = false

id := os.Getenv("AWS_ACCESS_KEY_ID")
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_minio.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewEnvMinio() *Credentials {
}

// Retrieve retrieves the keys from the environment.
func (e *EnvMinio) Retrieve() (Value, error) {
func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) {
e.retrieved = false

id := os.Getenv("MINIO_ROOT_USER")
Expand Down
6 changes: 3 additions & 3 deletions pkg/credentials/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve()
creds, err := e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
SignerType: SignatureV4,
}

creds, err = e.Retrieve()
creds, err = e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve()
creds, err := e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/file_aws_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials {

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileAWSCredentials) Retrieve() (Value, error) {
func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) {
if p.Filename == "" {
p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if p.Filename == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/file_minio_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewFileMinioClient(filename, alias string) *Credentials {

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileMinioClient) Retrieve() (Value, error) {
func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) {
if p.Filename == "" {
if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok {
p.Filename = value
Expand Down
Loading

0 comments on commit 4a691e1

Please sign in to comment.