From 20269ab46e1d2e6e591bc4e1431a186c53c5d9bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Wed, 3 Jul 2024 14:41:29 +0200 Subject: [PATCH] gcp: pass context to metadata functions (#3228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- internal/attestation/gcp/es/issuer_test.go | 6 ++--- internal/attestation/gcp/metadata.go | 28 +++++++++++----------- internal/attestation/gcp/snp/issuer.go | 12 +++++----- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/internal/attestation/gcp/es/issuer_test.go b/internal/attestation/gcp/es/issuer_test.go index 4836628553..d8d0075de7 100644 --- a/internal/attestation/gcp/es/issuer_test.go +++ b/internal/attestation/gcp/es/issuer_test.go @@ -91,14 +91,14 @@ type fakeMetadataClient struct { zoneErr error } -func (c fakeMetadataClient) ProjectID() (string, error) { +func (c fakeMetadataClient) ProjectID(_ context.Context) (string, error) { return c.projectIDString, c.projecIDErr } -func (c fakeMetadataClient) InstanceName() (string, error) { +func (c fakeMetadataClient) InstanceName(_ context.Context) (string, error) { return c.instanceNameString, c.instanceNameErr } -func (c fakeMetadataClient) Zone() (string, error) { +func (c fakeMetadataClient) Zone(_ context.Context) (string, error) { return c.zoneString, c.zoneErr } diff --git a/internal/attestation/gcp/metadata.go b/internal/attestation/gcp/metadata.go index d0e32eb8a0..471eceb99e 100644 --- a/internal/attestation/gcp/metadata.go +++ b/internal/attestation/gcp/metadata.go @@ -21,17 +21,17 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit // Ideally we would want to use the endorsement public key certificate // However, this is not available on GCE instances // Workaround: Provide ShieldedVM instance info - // The attestating party can request the VMs signing key using Google's API - return func(context.Context, io.ReadWriteCloser, []byte) ([]byte, error) { - projectID, err := client.ProjectID() + // The attesting party can request the VMs signing key using Google's API + return func(ctx context.Context, _ io.ReadWriteCloser, _ []byte) ([]byte, error) { + projectID, err := client.ProjectID(ctx) if err != nil { return nil, errors.New("unable to fetch projectID") } - zone, err := client.Zone() + zone, err := client.Zone(ctx) if err != nil { return nil, errors.New("unable to fetch zone") } - instanceName, err := client.InstanceName() + instanceName, err := client.InstanceName(ctx) if err != nil { return nil, errors.New("unable to fetch instance name") } @@ -45,25 +45,25 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit } type gcpMetadataClient interface { - ProjectID() (string, error) - InstanceName() (string, error) - Zone() (string, error) + ProjectID(context.Context) (string, error) + InstanceName(context.Context) (string, error) + Zone(context.Context) (string, error) } // A MetadataClient fetches metadata from the GCE Metadata API. type MetadataClient struct{} // ProjectID returns the project ID of the GCE instance. -func (c MetadataClient) ProjectID() (string, error) { - return metadata.ProjectIDWithContext(context.Background()) +func (c MetadataClient) ProjectID(ctx context.Context) (string, error) { + return metadata.ProjectIDWithContext(ctx) } // InstanceName returns the instance name of the GCE instance. -func (c MetadataClient) InstanceName() (string, error) { - return metadata.InstanceNameWithContext(context.Background()) +func (c MetadataClient) InstanceName(ctx context.Context) (string, error) { + return metadata.InstanceNameWithContext(ctx) } // Zone returns the zone the GCE instance is located in. -func (c MetadataClient) Zone() (string, error) { - return metadata.ZoneWithContext(context.Background()) +func (c MetadataClient) Zone(ctx context.Context) (string, error) { + return metadata.ZoneWithContext(ctx) } diff --git a/internal/attestation/gcp/snp/issuer.go b/internal/attestation/gcp/snp/issuer.go index ff5b2fd16a..215b21c8d7 100644 --- a/internal/attestation/gcp/snp/issuer.go +++ b/internal/attestation/gcp/snp/issuer.go @@ -57,7 +57,7 @@ func getAttestationKey(tpm io.ReadWriter) (*tpmclient.Key, error) { // getInstanceInfo generates an extended SNP report, i.e. the report and any loaded certificates. // Report generation is triggered by sending ioctl syscalls to the SNP guest device, the AMD PSP generates the report. // The returned bytes will be written into the attestation document. -func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) { +func getInstanceInfo(ctx context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) { if len(extraData) > 64 { return nil, fmt.Errorf("extra data too long: %d, should be 64 bytes at most", len(extraData)) } @@ -74,7 +74,7 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte) return nil, fmt.Errorf("parsing vcek: %w", err) } - gceInstanceInfo, err := gceInstanceInfo() + gceInstanceInfo, err := gceInstanceInfo(ctx) if err != nil { return nil, fmt.Errorf("getting GCE instance info: %w", err) } @@ -93,20 +93,20 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte) } // gceInstanceInfo returns the instance info for a GCE instance from the metadata API. -func gceInstanceInfo() (*attest.GCEInstanceInfo, error) { +func gceInstanceInfo(ctx context.Context) (*attest.GCEInstanceInfo, error) { c := gcp.MetadataClient{} - instanceName, err := c.InstanceName() + instanceName, err := c.InstanceName(ctx) if err != nil { return nil, fmt.Errorf("getting instance name: %w", err) } - projectID, err := c.ProjectID() + projectID, err := c.ProjectID(ctx) if err != nil { return nil, fmt.Errorf("getting project ID: %w", err) } - zone, err := c.Zone() + zone, err := c.Zone(ctx) if err != nil { return nil, fmt.Errorf("getting zone: %w", err) }