Skip to content

Commit

Permalink
Implement Azure TDX attestation primitives
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse committed Jan 17, 2024
1 parent 6259815 commit d523afc
Show file tree
Hide file tree
Showing 20 changed files with 770 additions and 243 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ require (
github.com/google/go-attestation v0.5.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-containerregistry v0.15.2 // indirect
github.com/google/go-tdx-guest v0.2.3-0.20231011100059-4cf02bed9d33 // indirect
github.com/google/go-tdx-guest v0.2.3-0.20231011100059-4cf02bed9d33
github.com/google/go-tspi v0.3.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/logger v1.1.1 // indirect
Expand Down
19 changes: 19 additions & 0 deletions internal/attestation/azure/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")

go_library(
name = "azure",
srcs = ["azure.go"],
importpath = "github.com/edgelesssys/constellation/v2/internal/attestation/azure",
visibility = ["//:__subpackages__"],
deps = [
"@com_github_google_go_tpm//legacy/tpm2",
"@com_github_google_go_tpm_tools//client",
],
)

go_test(
name = "azure_test",
srcs = ["azure_test.go"],
embed = [":azure"],
deps = [
"//internal/attestation/simulator",
"//internal/attestation/snp",
"@com_github_google_go_tpm//legacy/tpm2",
"@com_github_google_go_tpm_tools//client",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)
92 changes: 92 additions & 0 deletions internal/attestation/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,95 @@ Constellation supports multiple attestation technologies on Azure.
Basic TPM attestation.
*/
package azure

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"

tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/legacy/tpm2"
)

const (
// tpmAkIdx is the NV index of the attestation key used by Azure VMs.
tpmAkIdx = 0x81000003
)

// GetAttestationKey reads the attestation key put into the TPM during early boot.
func GetAttestationKey(tpm io.ReadWriter) (*tpmclient.Key, error) {
ak, err := tpmclient.LoadCachedKey(tpm, tpmAkIdx, tpmclient.NullSession{})
if err != nil {
return nil, fmt.Errorf("reading HCL attestation key from TPM: %w", err)
}

return ak, nil
}

// HCLAkValidator validates an attestation key issued by the Host Compatibility Layer (HCL).
// The HCL is written by Azure, and sits between the Hypervisor and CVM OS.
// The HCL runs in the protected context of the CVM.
type HCLAkValidator struct{}

// Validate validates that the attestation key from the TPM is trustworthy. The steps are:
// 1. runtime data read from the TPM has the same sha256 digest as reported in `report_data` of the SNP report or `TdQuoteBody.ReportData` of the TDX report.
// 2. modulus reported in runtime data matches modulus from key at idx 0x81000003.
// 3. exponent reported in runtime data matches exponent from key at idx 0x81000003.
// The function is currently tested manually on a Azure Ubuntu CVM.
func (a *HCLAkValidator) Validate(runtimeDataRaw []byte, reportData []byte, rsaParameters *tpm2.RSAParams) error {
var rtData runtimeData
if err := json.Unmarshal(runtimeDataRaw, &rtData); err != nil {
return fmt.Errorf("unmarshalling json: %w", err)
}

sum := sha256.Sum256(runtimeDataRaw)
if len(reportData) < len(sum) {
return fmt.Errorf("reportData has unexpected size: %d", len(reportData))
}
if !bytes.Equal(sum[:], reportData[:len(sum)]) {
return errors.New("unexpected runtimeData digest in TPM")
}

if len(rtData.PublicPart) < 1 {
return errors.New("did not receive any keys in runtime data")
}
rawN, err := base64.RawURLEncoding.DecodeString(rtData.PublicPart[0].N)
if err != nil {
return fmt.Errorf("decoding modulus string: %w", err)
}
if !bytes.Equal(rawN, rsaParameters.ModulusRaw) {
return fmt.Errorf("unexpected modulus value in TPM")
}

rawE, err := base64.RawURLEncoding.DecodeString(rtData.PublicPart[0].E)
if err != nil {
return fmt.Errorf("decoding exponent string: %w", err)
}
paddedRawE := make([]byte, 4)
copy(paddedRawE, rawE)
exponent := binary.LittleEndian.Uint32(paddedRawE)

// According to this comment [1] the TPM uses "0" to represent the default exponent "65537".
// The go tpm library also reports the exponent as 0. Thus we have to handle it specially.
// [1] https://github.com/tpm2-software/tpm2-tools/pull/1973#issue-596685005
if !((exponent == 65537 && rsaParameters.ExponentRaw == 0) || exponent == rsaParameters.ExponentRaw) {
return fmt.Errorf("unexpected N value in TPM")
}

return nil
}

type runtimeData struct {
PublicPart []akPub `json:"keys"`
}

// akPub are the public parameters of an RSA attestation key.
type akPub struct {
E string `json:"e"`
N string `json:"n"`
}
165 changes: 165 additions & 0 deletions internal/attestation/azure/azure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package azure

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"os"
"testing"

"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/attestation/snp"
"github.com/google/go-tpm-tools/client"
tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestValidateAk tests the attestation key validation with a simulated TPM device.
func TestValidateAk(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}

int32ToBytes := func(val uint32) []byte {
r := make([]byte, 4)
binary.PutUvarint(r, uint64(val))
return r
}

require := require.New(t)

tpm, err := simulator.OpenSimulatedTPM()
require.NoError(err)
defer tpm.Close()
key, err := client.AttestationKeyRSA(tpm)
require.NoError(err)
defer key.Close()

e := base64.RawURLEncoding.EncodeToString(int32ToBytes(key.PublicArea().RSAParameters.ExponentRaw))
n := base64.RawURLEncoding.EncodeToString(key.PublicArea().RSAParameters.ModulusRaw)
ak := akPub{E: e, N: n}
rtData := runtimeData{PublicPart: []akPub{ak}}

defaultRuntimeDataRaw, err := json.Marshal(rtData)
require.NoError(err)
defaultInstanceInfo := snp.InstanceInfo{Azure: &snp.AzureInstanceInfo{RuntimeData: defaultRuntimeDataRaw}}

sig := sha256.Sum256(defaultRuntimeDataRaw)
defaultReportData := sig[:]
defaultRsaParams := key.PublicArea().RSAParameters

testCases := map[string]struct {
instanceInfo snp.InstanceInfo
runtimeDataRaw []byte
reportData []byte
rsaParameters *tpm2.RSAParams
wantErr bool
}{
"success": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: defaultRsaParams,
},
"invalid json": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: []byte(""),
reportData: defaultReportData,
rsaParameters: defaultRsaParams,
wantErr: true,
},
"invalid hash": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: bytes.Repeat([]byte{0}, 64),
rsaParameters: defaultRsaParams,
wantErr: true,
},
"invalid E": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: func() *tpm2.RSAParams {
tmp := *defaultRsaParams
tmp.ExponentRaw = 1
return &tmp
}(),
wantErr: true,
},
"invalid N": {
instanceInfo: defaultInstanceInfo,
runtimeDataRaw: defaultRuntimeDataRaw,
reportData: defaultReportData,
rsaParameters: func() *tpm2.RSAParams {
tmp := *defaultRsaParams
tmp.ModulusRaw = []byte{0, 1, 2, 3}
return &tmp
}(),
wantErr: true,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ak := HCLAkValidator{}
err = ak.Validate(tc.runtimeDataRaw, tc.reportData, tc.rsaParameters)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

// TestGetHCLAttestationKey is a basic smoke test that only checks if GetAttestationKey can be run error free.
// Testing anything else will only verify that the simulator works as expected, since GetAttestationKey
// only retrieves the attestation key from the TPM.
func TestGetHCLAttestationKey(t *testing.T) {
cgo := os.Getenv("CGO_ENABLED")
if cgo == "0" {
t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
}
require := require.New(t)
assert := assert.New(t)

tpm, err := simulator.OpenSimulatedTPM()
require.NoError(err)
defer tpm.Close()

// we should receive an error if no key was saved at index `tpmAkIdx`
_, err = GetAttestationKey(tpm)
assert.Error(err)

// create a key at the index
tpmAk, err := tpmclient.NewCachedKey(tpm, tpm2.HandleOwner, tpm2.Public{
Type: tpm2.AlgRSA,
NameAlg: tpm2.AlgSHA256,
Attributes: tpm2.FlagFixedTPM | tpm2.FlagFixedParent | tpm2.FlagSensitiveDataOrigin | tpm2.FlagUserWithAuth | tpm2.FlagNoDA | tpm2.FlagRestricted | tpm2.FlagSign,
RSAParameters: &tpm2.RSAParams{
Sign: &tpm2.SigScheme{
Alg: tpm2.AlgRSASSA,
Hash: tpm2.AlgSHA256,
},
KeyBits: 2048,
},
}, tpmAkIdx)
require.NoError(err)
defer tpmAk.Close()

// we should now be able to retrieve the key
_, err = GetAttestationKey(tpm)
assert.NoError(err)
}
2 changes: 1 addition & 1 deletion internal/attestation/azure/snp/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ go_library(
visibility = ["//:__subpackages__"],
deps = [
"//internal/attestation",
"//internal/attestation/azure",
"//internal/attestation/idkeydigest",
"//internal/attestation/snp",
"//internal/attestation/variant",
Expand All @@ -28,7 +29,6 @@ go_library(
"@com_github_google_go_sev_guest//verify",
"@com_github_google_go_sev_guest//verify/trust",
"@com_github_google_go_tpm//legacy/tpm2",
"@com_github_google_go_tpm_tools//client",
"@com_github_google_go_tpm_tools//proto/attest",
],
)
Expand Down
16 changes: 2 additions & 14 deletions internal/attestation/azure/snp/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ import (
"io"

"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure"
"github.com/edgelesssys/constellation/v2/internal/attestation/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/go-azguestattestation/maa"
tpmclient "github.com/google/go-tpm-tools/client"
)

const tpmAkIdx = 0x81000003

// Issuer for Azure TPM attestation.
type Issuer struct {
variant.AzureSEVSNP
Expand All @@ -40,7 +38,7 @@ func NewIssuer(log attestation.Logger) *Issuer {

i.Issuer = vtpm.NewIssuer(
vtpm.OpenVTPM,
getAttestationKey,
azure.GetAttestationKey,
i.getInstanceInfo,
log,
)
Expand Down Expand Up @@ -83,16 +81,6 @@ func (i *Issuer) getInstanceInfo(ctx context.Context, tpm io.ReadWriteCloser, us
return statement, nil
}

// getAttestationKey reads the attestation key put into the TPM during early boot.
func getAttestationKey(tpm io.ReadWriter) (*tpmclient.Key, error) {
ak, err := tpmclient.LoadCachedKey(tpm, tpmAkIdx, tpmclient.NullSession{})
if err != nil {
return nil, fmt.Errorf("reading HCL attestation key from TPM: %w", err)
}

return ak, nil
}

type imdsAPI interface {
getMAAURL(ctx context.Context) (string, error)
}
Expand Down
Loading

0 comments on commit d523afc

Please sign in to comment.