Skip to content

Commit

Permalink
Add some tests for TDX attestation
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse committed Jan 18, 2024
1 parent 5d67668 commit 39d4af8
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 20 deletions.
12 changes: 12 additions & 0 deletions internal/attestation/azure/tdx/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")

go_library(
name = "tdx",
Expand All @@ -24,3 +25,14 @@ go_library(
"@com_github_google_go_tpm_tools//proto/attest",
],
)

go_test(
name = "tdx_test",
srcs = ["issuer_test.go"],
embed = [":tdx"],
deps = [
"//internal/attestation/azure/tdx/testdata",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)
43 changes: 23 additions & 20 deletions internal/attestation/azure/tdx/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,56 +78,59 @@ func (i *Issuer) getInstanceInfo(ctx context.Context, tpm io.ReadWriteCloser, _
return nil, err
}

// Parse the report and get quote
instanceInfo, err := i.getHCLReport(ctx, report)
// Parse the report from the TPM
hwReport, runtimeData, err := parseHCLReport(report)
if err != nil {
return nil, fmt.Errorf("getting HCL report: %w", err)
}

// Get quote from IMDS API
quote, err := i.quoteGetter.getQuote(ctx, hwReport)
if err != nil {
return nil, fmt.Errorf("getting quote: %w", err)
}

instanceInfo := instanceInfo{
AttestationReport: quote,
RuntimeData: runtimeData,
}
instanceInfoJSON, err := json.Marshal(instanceInfo)
if err != nil {
return nil, fmt.Errorf("marshalling instance info: %w", err)
}
return instanceInfoJSON, nil
}

func (i *Issuer) getHCLReport(ctx context.Context, report []byte) (instanceInfo, error) {
func parseHCLReport(report []byte) (hwReport, runtimeData []byte, err error) {
// First, ensure the extracted report is actually for TDX
if len(report) < runtimeDataSizeOffset+4 {
return instanceInfo{}, fmt.Errorf("invalid HCL report: expected at least %d bytes to read HCL report type, got %d", runtimeDataSizeOffset+4, len(report))
if len(report) < hclReportTypeOffsetStart+4 {
return nil, nil, fmt.Errorf("invalid HCL report: expected at least %d bytes to read HCL report type, got %d", runtimeDataSizeOffset+4, len(report))
}
reportType := binary.LittleEndian.Uint32(report[hclReportTypeOffsetStart : hclReportTypeOffsetStart+4])
if reportType != hclReportTypeTDX {
return instanceInfo{}, fmt.Errorf("invalid HCL report type: expected TDX (%d), got %d", hclReportTypeTDX, reportType)
return nil, nil, fmt.Errorf("invalid HCL report type: expected TDX (%d), got %d", hclReportTypeTDX, reportType)
}

// We need the td report (generally called HW report in Azure's samples) from the HCL report to send to the IMDS API
if len(report) < hwReportStart+tdReportSize {
return instanceInfo{}, fmt.Errorf("invalid HCL report: expected at least %d bytes to read td report, got %d", hwReportStart+tdReportSize, len(report))
return nil, nil, fmt.Errorf("invalid HCL report: expected at least %d bytes to read td report, got %d", hwReportStart+tdReportSize, len(report))
}
hwReport := report[hwReportStart : hwReportStart+tdReportSize]
hwReport = report[hwReportStart : hwReportStart+tdReportSize]

// We also need the runtime data to verify the attestation key later on the validator side
if len(report) < runtimeDataSizeOffset+4 {
return instanceInfo{}, fmt.Errorf("invalid HCL report: expected at least %d bytes to read runtime data size, got %d", runtimeDataSizeOffset+4, len(report))
return nil, nil, fmt.Errorf("invalid HCL report: expected at least %d bytes to read runtime data size, got %d", runtimeDataSizeOffset+4, len(report))
}
runtimeDataSize := binary.LittleEndian.Uint32(report[runtimeDataSizeOffset : runtimeDataSizeOffset+4])
if len(report) < runtimeDataOffset+int(runtimeDataSize) {
return instanceInfo{}, fmt.Errorf("invalid HCL report: expected at least %d bytes to read runtime data, got %d", runtimeDataOffset+int(runtimeDataSize), len(report))
return nil, nil, fmt.Errorf("invalid HCL report: expected at least %d bytes to read runtime data, got %d", runtimeDataOffset+int(runtimeDataSize), len(report))
}
runtimeData := report[runtimeDataOffset : runtimeDataOffset+runtimeDataSize]
runtimeData = report[runtimeDataOffset : runtimeDataOffset+runtimeDataSize]

quote, err := i.quoteGetter.getQuote(ctx, hwReport)
if err != nil {
return instanceInfo{}, fmt.Errorf("getting quote: %w", err)
}

return instanceInfo{
AttestationReport: quote,
RuntimeData: runtimeData,
}, nil
return hwReport, runtimeData, nil
}

// imdsQuoteGetter issues TDX quotes using Azure's IMDS API.
type imdsQuoteGetter struct {
client *http.Client
}
Expand Down
158 changes: 158 additions & 0 deletions internal/attestation/azure/tdx/issuer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package tdx

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"io"
"net/http"
"testing"

"github.com/edgelesssys/constellation/v2/internal/attestation/azure/tdx/testdata"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestParseHCLReport(t *testing.T) {
testCases := map[string]struct {
report []byte
wantErr bool
}{
"success using testdata": {
report: testdata.HCLReport,
wantErr: false,
},
"invalid report type": {
report: func() []byte {
report := make([]byte, len(testdata.HCLReport))
copy(report, testdata.HCLReport)
binary.LittleEndian.PutUint32(report[hclReportTypeOffsetStart:], hclReportTypeInvalid)
return report
}(),
wantErr: true,
},
"report too short for HCL report type": {
report: func() []byte {
report := make([]byte, hclReportTypeOffsetStart+3)
copy(report, testdata.HCLReport)
return report
}(),
wantErr: true,
},
"report too short for runtime data size": {
report: func() []byte {
report := make([]byte, runtimeDataSizeOffset+3)
copy(report, testdata.HCLReport)
return report
}(),
wantErr: true,
},
"runtime data shorter than runtime data size": {
report: func() []byte {
report := make([]byte, len(testdata.HCLReport))
copy(report, testdata.HCLReport)
// Lets claim the report contains a much larger runtime data entry than it actually does.
// That way, we can easily test if our code correctly handles reports that are shorter than
// what they claim to be and avoid panics.
binary.LittleEndian.PutUint32(report[runtimeDataSizeOffset:], 0xFFFF)
return report
}(),
wantErr: true,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)

hwReport, runtimeData, err := parseHCLReport(tc.report)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(hwReport)
assert.NotNil(runtimeData)
})
}
}

func TestIMDSGetQuote(t *testing.T) {
testCases := map[string]struct {
client *http.Client
wantErr bool
}{
"success": {
client: newTestClient(func(req *http.Request) *http.Response {
quote := quoteResponse{
Quote: "test",
}
b, err := json.Marshal(quote)
require.NoError(t, err)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(b)),
}
},
),
wantErr: false,
},
"bad status code": {
client: newTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("")),
}
},
),
wantErr: true,
},
"bad json": {
client: newTestClient(func(req *http.Request) *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("")),
}
},
),
wantErr: true,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)

quoteGetter := imdsQuoteGetter{
client: tc.client,
}

_, err := quoteGetter.getQuote(context.Background(), []byte("test"))
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

type roundTripFunc func(req *http.Request) *http.Response

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}

// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
9 changes: 9 additions & 0 deletions internal/attestation/azure/tdx/testdata/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "testdata",
srcs = ["testdata.go"],
embedsrcs = ["hclreport.bin"],
importpath = "github.com/edgelesssys/constellation/v2/internal/attestation/azure/tdx/testdata",
visibility = ["//:__subpackages__"],
)
Binary file not shown.
15 changes: 15 additions & 0 deletions internal/attestation/azure/tdx/testdata/testdata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

// Package testdata contains testing data for an attestation process.
package testdata

import _ "embed"

// HCLReport is an example HCL report from an Azure TDX VM.
//
//go:embed hclreport.bin
var HCLReport []byte
6 changes: 6 additions & 0 deletions internal/attestation/choose/choose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func TestIssuer(t *testing.T) {
"azure-sev-snp": {
variant: variant.AzureSEVSNP{},
},
"azure-tdx": {
variant: variant.AzureTDX{},
},
"azure-trusted-launch": {
variant: variant.AzureTrustedLaunch{},
},
Expand Down Expand Up @@ -77,6 +80,9 @@ func TestValidator(t *testing.T) {
"azure-sev-snp": {
cfg: &config.AzureSEVSNP{},
},
"azure-tdx": {
cfg: &config.AzureTDX{},
},
"azure-trusted-launch": {
cfg: &config.AzureTrustedLaunch{},
},
Expand Down

0 comments on commit 39d4af8

Please sign in to comment.