diff --git a/api/attestationconfigapi/fetcher.go b/api/attestationconfigapi/fetcher.go index 362c41bf58..c1e25ed022 100644 --- a/api/attestationconfigapi/fetcher.go +++ b/api/attestationconfigapi/fetcher.go @@ -18,9 +18,31 @@ import ( const cosignPublicKey = constants.CosignPublicKeyReleases +var ( + // AzureSEVSNP is the Azure SEV-SNP variant. + AzureSEVSNP Variant = attestationVariant{variant: variant.AzureSEVSNP{}} + // AWSSEVSNP is the AWS SEV-SNP variant. + AWSSEVSNP Variant = attestationVariant{variant: variant.AWSSEVSNP{}} + // GCPSEVSNP is the GCP SEV-SNP variant. + GCPSEVSNP Variant = attestationVariant{variant: variant.GCPSEVSNP{}} +) + +type attestationVariant struct { + variant Variant +} + +func (v attestationVariant) String() string { + return v.variant.String() +} + +// Variant is a cloud provider specific attestation variant. +type Variant interface { + String() string +} + // Fetcher fetches config API resources without authentication. type Fetcher interface { - FetchLatestVersion(ctx context.Context, attestation fmt.Stringer) (Entry, error) + FetchLatestVersion(ctx context.Context, attestation Variant) (Entry, error) } // fetcher fetches AttestationCfg API resources without authentication. @@ -60,7 +82,7 @@ func newFetcherWithClientAndVerifier(client apifetcher.HTTPClient, cosignVerifie } // FetchLatestVersion returns the latest versions of the given type. -func (f *fetcher) FetchLatestVersion(ctx context.Context, variant fmt.Stringer) (Entry, error) { +func (f *fetcher) FetchLatestVersion(ctx context.Context, variant Variant) (Entry, error) { list, err := f.fetchVersionList(ctx, variant) if err != nil { return Entry{}, err @@ -71,7 +93,7 @@ func (f *fetcher) FetchLatestVersion(ctx context.Context, variant fmt.Stringer) } // fetchVersionList fetches the version list information from the config API. -func (f *fetcher) fetchVersionList(ctx context.Context, attestationVariant fmt.Stringer) (List, error) { +func (f *fetcher) fetchVersionList(ctx context.Context, attestationVariant Variant) (List, error) { parsedVariant, err := variant.FromString(attestationVariant.String()) if err != nil { return List{}, fmt.Errorf("parsing variant: %w", err) @@ -88,7 +110,7 @@ func (f *fetcher) fetchVersionList(ctx context.Context, attestationVariant fmt.S } // fetchVersion fetches the version information from the config API. -func (f *fetcher) fetchVersion(ctx context.Context, version string, attestationVariant fmt.Stringer) (Entry, error) { +func (f *fetcher) fetchVersion(ctx context.Context, version string, attestationVariant Variant) (Entry, error) { parsedVariant, err := variant.FromString(attestationVariant.String()) if err != nil { return Entry{}, fmt.Errorf("parsing variant: %w", err) diff --git a/cli/internal/cmd/configfetchmeasurements_test.go b/cli/internal/cmd/configfetchmeasurements_test.go index 84f99ee962..33becae925 100644 --- a/cli/internal/cmd/configfetchmeasurements_test.go +++ b/cli/internal/cmd/configfetchmeasurements_test.go @@ -8,7 +8,6 @@ package cmd import ( "context" - "fmt" "net/http" "net/url" "testing" @@ -205,7 +204,7 @@ func (f stubVerifyFetcher) FetchAndVerifyMeasurements(_ context.Context, _ strin type stubAttestationFetcher struct{} -func (f stubAttestationFetcher) FetchLatestVersion(_ context.Context, _ fmt.Stringer) (attestationconfigapi.Entry, error) { +func (f stubAttestationFetcher) FetchLatestVersion(_ context.Context, _ attestationconfigapi.Variant) (attestationconfigapi.Entry, error) { return attestationconfigapi.Entry{ SEVSNPVersion: testCfg, }, nil diff --git a/cli/internal/cmd/iamupgradeapply_test.go b/cli/internal/cmd/iamupgradeapply_test.go index 8e11b5018e..2e62e6cdcb 100644 --- a/cli/internal/cmd/iamupgradeapply_test.go +++ b/cli/internal/cmd/iamupgradeapply_test.go @@ -7,7 +7,6 @@ package cmd import ( "context" - "fmt" "io" "path/filepath" "strings" @@ -171,6 +170,6 @@ type stubConfigFetcher struct { fetchLatestErr error } -func (s *stubConfigFetcher) FetchLatestVersion(context.Context, fmt.Stringer) (attestationconfigapi.Entry, error) { +func (s *stubConfigFetcher) FetchLatestVersion(context.Context, attestationconfigapi.Variant) (attestationconfigapi.Entry, error) { return attestationconfigapi.Entry{}, s.fetchLatestErr } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 279803b2b0..acb2d5f54a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -9,7 +9,6 @@ package config import ( "context" "errors" - "fmt" "reflect" "testing" @@ -1052,7 +1051,7 @@ func getConfigAsMap(conf *Config, t *testing.T) (res configMap) { type stubAttestationFetcher struct{} -func (f stubAttestationFetcher) FetchLatestVersion(_ context.Context, _ fmt.Stringer) (attestationconfigapi.Entry, error) { +func (f stubAttestationFetcher) FetchLatestVersion(_ context.Context, _ attestationconfigapi.Variant) (attestationconfigapi.Entry, error) { return attestationconfigapi.Entry{ SEVSNPVersion: testCfg, }, nil