Skip to content

Commit

Permalink
Bring back Retrieve() interface to be compatible and not break clients
Browse files Browse the repository at this point in the history
This was broken in #2041
  • Loading branch information
harshavardhana committed Dec 30, 2024
1 parent 4a691e1 commit 7766336
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 56 deletions.
16 changes: 13 additions & 3 deletions pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
return a, nil
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
client := m.Client
if client == nil {
client = cc.Client
Expand All @@ -243,3 +241,15 @@ func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
SignerType: SignatureV4,
}, nil
}

// RetrieveWithCredContext retrieves credentials from the MinIO service.
// Error will be returned if the request fails, optional cred context.
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
}
22 changes: 20 additions & 2 deletions pkg/credentials/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,32 @@ func NewChainCredentials(providers []Provider) *Credentials {
})
}

// RetrieveWithCredContext is like Retrieve with CredContext
func (c *Chain) RetrieveWithCredContext(cc *CredContext) (Value, error) {
for _, p := range c.Providers {
creds, _ := p.RetrieveWithCredContext(cc)
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
}
c.curr = p
return creds, nil
}
// At this point we have exhausted all the providers and
// are left without any credentials return anonymous.
return Value{
SignerType: SignatureAnonymous,
}, nil
}

// Retrieve returns the credentials value, returns no credentials(anonymous)
// if no credentials provider returned any value.
//
// If a provider is found with credentials, it will be cached and any calls
// to IsExpired() will return the expired state of the cached provider.
func (c *Chain) Retrieve(cc *CredContext) (Value, error) {
func (c *Chain) Retrieve() (Value, error) {
for _, p := range c.Providers {
creds, _ := p.Retrieve(cc)
creds, _ := p.RetrieveWithCredContext(defaultCredContext)
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
Expand Down
15 changes: 10 additions & 5 deletions pkg/credentials/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ type testCredProvider struct {
err error
}

func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) {
func (s *testCredProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}

func (s *testCredProvider) RetrieveWithCredContext(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand Down Expand Up @@ -59,7 +64,7 @@ func TestChainGet(t *testing.T) {
},
}

creds, err := p.Retrieve(defaultCredContext)
creds, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -95,7 +100,7 @@ func TestChainIsExpired(t *testing.T) {
t.Fatal("Expected expired to be true before any Retrieve")
}

_, err := p.Retrieve(defaultCredContext)
_, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -112,7 +117,7 @@ func TestChainWithNoProvider(t *testing.T) {
if !p.IsExpired() {
t.Fatal("Expected to be expired with no providers")
}
_, err := p.Retrieve(defaultCredContext)
_, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found []" {
t.Error(err)
Expand All @@ -136,7 +141,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
t.Fatal("Expected to be expired with no providers")
}

_, err := p.Retrieve(defaultCredContext)
_, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found [FirstError SecondError]" {
t.Error(err)
Expand Down
9 changes: 7 additions & 2 deletions pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ type Value struct {
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
type Provider interface {
// RetrieveWithCredContext returns nil if it successfully retrieved the
// value. Error is returned if the value were not obtainable, or empty.
// optionally takes CredContext for additional context to retrieve credentials.
RetrieveWithCredContext(cc *CredContext) (Value, error)

// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve(cc *CredContext) (Value, error)
Retrieve() (Value, error)

// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
Expand Down Expand Up @@ -185,7 +190,7 @@ func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
defer c.Unlock()

if c.isExpired() {
creds, err := c.provider.Retrieve(cc)
creds, err := c.provider.RetrieveWithCredContext(cc)
if err != nil {
return Value{}, err
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ type credProvider struct {
err error
}

func (s *credProvider) Retrieve(_ *CredContext) (Value, error) {
func (s *credProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}

func (s *credProvider) RetrieveWithCredContext(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand Down
13 changes: 11 additions & 2 deletions pkg/credentials/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ func NewEnvAWS() *Credentials {
return New(&EnvAWS{})
}

// Retrieve retrieves the keys from the environment.
func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) {
func (e *EnvAWS) retrieve() (Value, error) {
e.retrieved = false

id := os.Getenv("AWS_ACCESS_KEY_ID")
Expand All @@ -65,6 +64,16 @@ func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) {
}, nil
}

// Retrieve retrieves the keys from the environment.
func (e *EnvAWS) Retrieve() (Value, error) {
return e.retrieve()
}

// RetrieveWithContext is like Retrieve (no-op input of Cred Context)
func (e *EnvAWS) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return e.retrieve()
}

// IsExpired returns if the credentials have been retrieved.
func (e *EnvAWS) IsExpired() bool {
return !e.retrieved
Expand Down
13 changes: 11 additions & 2 deletions pkg/credentials/env_minio.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ func NewEnvMinio() *Credentials {
return New(&EnvMinio{})
}

// Retrieve retrieves the keys from the environment.
func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) {
func (e *EnvMinio) retrieve() (Value, error) {
e.retrieved = false

id := os.Getenv("MINIO_ROOT_USER")
Expand All @@ -62,6 +61,16 @@ func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) {
}, nil
}

// Retrieve retrieves the keys from the environment.
func (e *EnvMinio) Retrieve() (Value, error) {
return e.retrieve()
}

// RetrieveWithCredContext is like Retrieve() (no-op input cred context)
func (e *EnvMinio) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return e.retrieve()
}

// IsExpired returns if the credentials have been retrieved.
func (e *EnvMinio) IsExpired() bool {
return !e.retrieved
Expand Down
6 changes: 3 additions & 3 deletions pkg/credentials/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve(defaultCredContext)
creds, err := e.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
SignerType: SignatureV4,
}

creds, err = e.Retrieve(defaultCredContext)
creds, err = e.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve(defaultCredContext)
creds, err := e.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down
15 changes: 12 additions & 3 deletions pkg/credentials/file_aws_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials {
})
}

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) {
func (p *FileAWSCredentials) retrieve() (Value, error) {
if p.Filename == "" {
p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if p.Filename == "" {
Expand Down Expand Up @@ -142,6 +140,17 @@ func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) {
}, nil
}

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileAWSCredentials) Retrieve() (Value, error) {
return p.retrieve()
}

// RetrieveWithCredContext is like Retrieve(), cred context is no-op for File credentials
func (p *FileAWSCredentials) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return p.retrieve()
}

// loadProfiles loads from the file pointed to by shared credentials filename for profile.
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
Expand Down
15 changes: 12 additions & 3 deletions pkg/credentials/file_minio_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ func NewFileMinioClient(filename, alias string) *Credentials {
})
}

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) {
func (p *FileMinioClient) retrieve() (Value, error) {
if p.Filename == "" {
if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok {
p.Filename = value
Expand Down Expand Up @@ -96,6 +94,17 @@ func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) {
}, nil
}

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileMinioClient) Retrieve() (Value, error) {
return p.retrieve()
}

// RetrieveWithCredContext - is like Retrieve()
func (p *FileMinioClient) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return p.retrieve()
}

// IsExpired returns if the shared credentials have expired.
func (p *FileMinioClient) IsExpired() bool {
return !p.retrieved
Expand Down
19 changes: 14 additions & 5 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ func NewIAM(endpoint string) *Credentials {
})
}

// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired
func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
func (m *IAM) retrieve(cc *CredContext) (Value, error) {
token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN")
if token == "" {
token = m.Container.AuthorizationToken
Expand Down Expand Up @@ -177,7 +174,7 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
roleSessionName: roleSessionName,
}

stsWebIdentityCreds, err := creds.Retrieve(cc)
stsWebIdentityCreds, err := creds.RetrieveWithCredContext(cc)
if err == nil {
m.SetExpiration(creds.Expiration(), DefaultExpiryWindow)
}
Expand Down Expand Up @@ -227,6 +224,18 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
}, nil
}

// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired
func (m *IAM) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
}

// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// A ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
type ec2RoleCredRespBody struct {
Expand Down
Loading

0 comments on commit 7766336

Please sign in to comment.