Skip to content

Commit

Permalink
gcp: pass context to metadata functions (#3228)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse authored Jul 3, 2024
1 parent 7b6c3a7 commit 20269ab
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
6 changes: 3 additions & 3 deletions internal/attestation/gcp/es/issuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ type fakeMetadataClient struct {
zoneErr error
}

func (c fakeMetadataClient) ProjectID() (string, error) {
func (c fakeMetadataClient) ProjectID(_ context.Context) (string, error) {
return c.projectIDString, c.projecIDErr
}

func (c fakeMetadataClient) InstanceName() (string, error) {
func (c fakeMetadataClient) InstanceName(_ context.Context) (string, error) {
return c.instanceNameString, c.instanceNameErr
}

func (c fakeMetadataClient) Zone() (string, error) {
func (c fakeMetadataClient) Zone(_ context.Context) (string, error) {
return c.zoneString, c.zoneErr
}
28 changes: 14 additions & 14 deletions internal/attestation/gcp/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit
// Ideally we would want to use the endorsement public key certificate
// However, this is not available on GCE instances
// Workaround: Provide ShieldedVM instance info
// The attestating party can request the VMs signing key using Google's API
return func(context.Context, io.ReadWriteCloser, []byte) ([]byte, error) {
projectID, err := client.ProjectID()
// The attesting party can request the VMs signing key using Google's API
return func(ctx context.Context, _ io.ReadWriteCloser, _ []byte) ([]byte, error) {
projectID, err := client.ProjectID(ctx)
if err != nil {
return nil, errors.New("unable to fetch projectID")
}
zone, err := client.Zone()
zone, err := client.Zone(ctx)
if err != nil {
return nil, errors.New("unable to fetch zone")
}
instanceName, err := client.InstanceName()
instanceName, err := client.InstanceName(ctx)
if err != nil {
return nil, errors.New("unable to fetch instance name")
}
Expand All @@ -45,25 +45,25 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit
}

type gcpMetadataClient interface {
ProjectID() (string, error)
InstanceName() (string, error)
Zone() (string, error)
ProjectID(context.Context) (string, error)
InstanceName(context.Context) (string, error)
Zone(context.Context) (string, error)
}

// A MetadataClient fetches metadata from the GCE Metadata API.
type MetadataClient struct{}

// ProjectID returns the project ID of the GCE instance.
func (c MetadataClient) ProjectID() (string, error) {
return metadata.ProjectIDWithContext(context.Background())
func (c MetadataClient) ProjectID(ctx context.Context) (string, error) {
return metadata.ProjectIDWithContext(ctx)
}

// InstanceName returns the instance name of the GCE instance.
func (c MetadataClient) InstanceName() (string, error) {
return metadata.InstanceNameWithContext(context.Background())
func (c MetadataClient) InstanceName(ctx context.Context) (string, error) {
return metadata.InstanceNameWithContext(ctx)
}

// Zone returns the zone the GCE instance is located in.
func (c MetadataClient) Zone() (string, error) {
return metadata.ZoneWithContext(context.Background())
func (c MetadataClient) Zone(ctx context.Context) (string, error) {
return metadata.ZoneWithContext(ctx)
}
12 changes: 6 additions & 6 deletions internal/attestation/gcp/snp/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func getAttestationKey(tpm io.ReadWriter) (*tpmclient.Key, error) {
// getInstanceInfo generates an extended SNP report, i.e. the report and any loaded certificates.
// Report generation is triggered by sending ioctl syscalls to the SNP guest device, the AMD PSP generates the report.
// The returned bytes will be written into the attestation document.
func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) {
func getInstanceInfo(ctx context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) {
if len(extraData) > 64 {
return nil, fmt.Errorf("extra data too long: %d, should be 64 bytes at most", len(extraData))
}
Expand All @@ -74,7 +74,7 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte)
return nil, fmt.Errorf("parsing vcek: %w", err)
}

gceInstanceInfo, err := gceInstanceInfo()
gceInstanceInfo, err := gceInstanceInfo(ctx)
if err != nil {
return nil, fmt.Errorf("getting GCE instance info: %w", err)
}
Expand All @@ -93,20 +93,20 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte)
}

// gceInstanceInfo returns the instance info for a GCE instance from the metadata API.
func gceInstanceInfo() (*attest.GCEInstanceInfo, error) {
func gceInstanceInfo(ctx context.Context) (*attest.GCEInstanceInfo, error) {
c := gcp.MetadataClient{}

instanceName, err := c.InstanceName()
instanceName, err := c.InstanceName(ctx)
if err != nil {
return nil, fmt.Errorf("getting instance name: %w", err)
}

projectID, err := c.ProjectID()
projectID, err := c.ProjectID(ctx)
if err != nil {
return nil, fmt.Errorf("getting project ID: %w", err)
}

zone, err := c.Zone()
zone, err := c.Zone(ctx)
if err != nil {
return nil, fmt.Errorf("getting zone: %w", err)
}
Expand Down

0 comments on commit 20269ab

Please sign in to comment.