Skip to content

Commit

Permalink
coordinator: don't share CA instance with gRPC servers
Browse files Browse the repository at this point in the history
  • Loading branch information
burgerdev committed Jun 6, 2024
1 parent de5e0f9 commit 7f7deaf
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 143 deletions.
82 changes: 55 additions & 27 deletions coordinator/mesh.go → coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package main
package authority

import (
"context"
Expand All @@ -23,24 +23,38 @@ import (
"github.com/google/go-sev-guest/validate"
)

type meshAuthority struct {
ca *ca.CA
certs map[string][]byte
certsMux sync.RWMutex
manifests appendableList[*manifest.Manifest]
logger *slog.Logger
// Bundle is a set of PEM-encoded certificates for Contrast workloads.
type Bundle struct {
WorkloadCert []byte
MeshCA []byte
IntermediateCA []byte
RootCA []byte
}

func newMeshAuthority(ca *ca.CA, log *slog.Logger) *meshAuthority {
return &meshAuthority{
ca: ca,
certs: make(map[string][]byte),
// Authority manages the manifest state of Contrast.
type Authority struct {
ca *ca.CA
bundles map[string]Bundle
bundlesMux sync.RWMutex
manifests appendableList[*manifest.Manifest]
logger *slog.Logger
}

// New creates a new Authority instance.
func New(caInstance *ca.CA, log *slog.Logger) *Authority {
return &Authority{
ca: caInstance,
bundles: make(map[string]Bundle),
manifests: new(appendable.Appendable[*manifest.Manifest]),
logger: log.WithGroup("mesh-authority"),
}
}

func (m *meshAuthority) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) {
// SNPValidateOpts returns SNP validation options from reference values.
//
// It also ensures that the policy hash in the report's HOSTDATA is allowed by the current
// manifest.
func (m *Authority) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) {
mnfst, err := m.manifests.Latest()
if err != nil {
return nil, fmt.Errorf("getting latest manifest: %w", err)
Expand Down Expand Up @@ -83,13 +97,16 @@ func (m *meshAuthority) SNPValidateOpts(report *sevsnp.Report) (*validate.Option
}, nil
}

func (m *meshAuthority) ValidateCallback(_ context.Context, report *sevsnp.Report,
// ValidateCallback creates a certificate bundle for the verified client.
func (m *Authority) ValidateCallback(_ context.Context, report *sevsnp.Report,
_ asn1.ObjectIdentifier, _, _, peerPubKeyBytes []byte,
) error {
mnfst, err := m.manifests.Latest()
if err != nil {
return fmt.Errorf("getting latest manifest: %w", err)
}
// TODO(burgerdev): The CA should be tied to the manifest.
caInstance := m.ca

hostData := manifest.NewHexString(report.HostData)
dnsNames, ok := mnfst.Policies[hostData]
Expand All @@ -106,7 +123,7 @@ func (m *meshAuthority) ValidateCallback(_ context.Context, report *sevsnp.Repor
if err != nil {
return fmt.Errorf("failed to construct extensions: %w", err)
}
cert, err := m.ca.NewAttestedMeshCert(dnsNames, extensions, peerPubKey)
cert, err := caInstance.NewAttestedMeshCert(dnsNames, extensions, peerPubKey)
if err != nil {
return fmt.Errorf("failed to issue new attested mesh cert: %w", err)
}
Expand All @@ -115,38 +132,49 @@ func (m *meshAuthority) ValidateCallback(_ context.Context, report *sevsnp.Repor
peerPublicKeyHashStr := hex.EncodeToString(peerPubKeyHash[:])
m.logger.Info("Validated peer", "peerPublicKeyHashStr", peerPublicKeyHashStr)

m.certsMux.Lock()
defer m.certsMux.Unlock()
m.certs[peerPublicKeyHashStr] = cert
m.bundlesMux.Lock()
defer m.bundlesMux.Unlock()
m.bundles[peerPublicKeyHashStr] = Bundle{
WorkloadCert: cert,
MeshCA: caInstance.GetMeshCACert(),
IntermediateCA: caInstance.GetIntermCACert(),
RootCA: caInstance.GetRootCACert(),
}

return nil
}

func (m *meshAuthority) GetCert(peerPublicKeyHashStr string) ([]byte, error) {
m.certsMux.RLock()
defer m.certsMux.RUnlock()
// GetCertBundle retrieves the certificate bundle created for the peer identified by the given public key.
func (m *Authority) GetCertBundle(peerPublicKeyHashStr string) (Bundle, error) {
m.bundlesMux.RLock()
defer m.bundlesMux.RUnlock()

bundle, ok := m.bundles[peerPublicKeyHashStr]

cert, ok := m.certs[peerPublicKeyHashStr]
if !ok {
return nil, fmt.Errorf("cert for peer public key %s not found", peerPublicKeyHashStr)
return Bundle{}, fmt.Errorf("cert for peer public key %s not found", peerPublicKeyHashStr)
}

return cert, nil
return bundle, nil
}

func (m *meshAuthority) GetManifests() []*manifest.Manifest {
return m.manifests.All()
// GetManifestsAndLatestCA retrieves the manifest history and the currently active CA instance.
func (m *Authority) GetManifestsAndLatestCA() ([]*manifest.Manifest, *ca.CA) {
// TODO(burgerdev): The CA should be tied to the manifest.
return m.manifests.All(), m.ca
}

func (m *meshAuthority) SetManifest(mnfst *manifest.Manifest) error {
// SetManifest updates the active manifest.
func (m *Authority) SetManifest(mnfst *manifest.Manifest) error {
if err := m.ca.RotateIntermCerts(); err != nil {
return fmt.Errorf("rotating intermediate certificates: %w", err)
}
m.manifests.Append(mnfst)
return nil
}

func (m *meshAuthority) LatestManifest() (*manifest.Manifest, error) {
// LatestManifest retrieves the active manifest.
func (m *Authority) LatestManifest() (*manifest.Manifest, error) {
return m.manifests.Latest()
}

Expand Down
7 changes: 4 additions & 3 deletions coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"os"

"github.com/edgelesssys/contrast/coordinator/internal/authority"
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/meshapi"
Expand Down Expand Up @@ -56,9 +57,9 @@ func run() (retErr error) {

promRegistry := prometheus.NewRegistry()

meshAuth := newMeshAuthority(caInstance, logger)
userAPI := newUserAPIServer(meshAuth, caInstance, promRegistry, logger)
meshAPI := newMeshAPIServer(meshAuth, caInstance, promRegistry, logger)
meshAuth := authority.New(caInstance, logger)
userAPI := newUserAPIServer(meshAuth, promRegistry, logger)
meshAPI := newMeshAPIServer(meshAuth, meshAuth, promRegistry, logger)

eg := errgroup.Group{}

Expand Down
71 changes: 46 additions & 25 deletions coordinator/meshapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ package main

import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"log/slog"
"net"
"time"

"github.com/edgelesssys/contrast/coordinator/internal/authority"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
Expand All @@ -20,27 +24,26 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"google.golang.org/grpc/peer"
"k8s.io/utils/clock"
)

type meshAPIServer struct {
grpc *grpc.Server
certGet certGetter
caChainGetter certChainGetter
ticker clock.Ticker
logger *slog.Logger
grpc *grpc.Server
bundleGetter certBundleGetter
ticker clock.Ticker
logger *slog.Logger

meshapi.UnimplementedMeshAPIServer
}

type certGetter interface {
GetCert(peerPublicKeyHashStr string) ([]byte, error)
type certBundleGetter interface {
GetCertBundle(peerPublicKeyHashStr string) (authority.Bundle, error)
}

func newMeshAPIServer(meshAuth *meshAuthority, caGetter certChainGetter, reg *prometheus.Registry, log *slog.Logger) *meshAPIServer {
func newMeshAPIServer(meshAuth *authority.Authority, bundleGetter certBundleGetter, reg *prometheus.Registry, log *slog.Logger) *meshAPIServer {
ticker := clock.RealClock{}.NewTicker(24 * time.Hour)
kdsGetter := snp.NewCachedHTTPSGetter(memstore.New[string, []byte](), ticker, logger.NewNamed(log, "kds-getter"))

Expand Down Expand Up @@ -74,11 +77,10 @@ func newMeshAPIServer(meshAuth *meshAuthority, caGetter certChainGetter, reg *pr
),
)
s := &meshAPIServer{
grpc: grpcServer,
certGet: meshAuth,
caChainGetter: caGetter,
ticker: ticker,
logger: log.WithGroup("meshapi"),
grpc: grpcServer,
bundleGetter: bundleGetter,
ticker: ticker,
logger: log.WithGroup("meshapi"),
}
meshapi.RegisterMeshAPIServer(s.grpc, s)

Expand All @@ -98,22 +100,41 @@ func (i *meshAPIServer) Serve(endpoint string) error {
return i.grpc.Serve(lis)
}

func (i *meshAPIServer) NewMeshCert(_ context.Context, req *meshapi.NewMeshCertRequest,
) (*meshapi.NewMeshCertResponse, error) {
func (i *meshAPIServer) NewMeshCert(ctx context.Context, _ *meshapi.NewMeshCertRequest) (*meshapi.NewMeshCertResponse, error) {
i.logger.Info("NewMeshCert called")

cert, err := i.certGet.GetCert(req.PeerPublicKeyHash)
// Fetch the peer public key from gRPC's TLS context and look up the corresponding cetificate.

p, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("failed to get peer from context")
}

tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
if !ok {
return nil, fmt.Errorf("failed to get TLS info from peer")
}

if len(tlsInfo.State.PeerCertificates) == 0 {
return nil, fmt.Errorf("no peer certificates found")
}

peerCert := tlsInfo.State.PeerCertificates[0]
peerPubKeyBytes, err := x509.MarshalPKIXPublicKey(peerCert.PublicKey)
if err != nil {
return nil, status.Errorf(codes.Internal,
"getting certificate with public key hash %q: %v", req.PeerPublicKeyHash, err)
return nil, fmt.Errorf("could not marshal public key: %w", err)
}
peerPubKeyHash := sha256.Sum256(peerPubKeyBytes)
peerPublicKeyHashStr := hex.EncodeToString(peerPubKeyHash[:])

meshCACert := i.caChainGetter.GetMeshCACert()
intermCert := i.caChainGetter.GetIntermCACert()
bundle, err := i.bundleGetter.GetCertBundle(peerPublicKeyHashStr)
if err != nil {
return nil, fmt.Errorf("server did not create a bundle for ")
}

return &meshapi.NewMeshCertResponse{
MeshCACert: meshCACert,
CertChain: append(cert, intermCert...),
RootCACert: i.caChainGetter.GetRootCACert(),
MeshCACert: bundle.MeshCA,
CertChain: append(bundle.WorkloadCert, bundle.IntermediateCA...),
RootCACert: bundle.RootCA,
}, nil
}
26 changes: 10 additions & 16 deletions coordinator/userapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/edgelesssys/contrast/internal/appendable"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
Expand All @@ -42,15 +43,14 @@ type userAPIServer struct {
grpc *grpc.Server
policyTextStore store[manifest.HexString, manifest.Policy]
manifSetGetter manifestSetGetter
caChainGetter certChainGetter
logger *slog.Logger
mux sync.RWMutex
metrics userAPIMetrics

userapi.UnimplementedUserAPIServer
}

func newUserAPIServer(mSGetter manifestSetGetter, caGetter certChainGetter, reg *prometheus.Registry, log *slog.Logger) *userAPIServer {
func newUserAPIServer(mSGetter manifestSetGetter, reg *prometheus.Registry, log *slog.Logger) *userAPIServer {
issuer := snp.NewIssuer(logger.NewNamed(log, "snp-issuer"))
credentials := atlscredentials.New(issuer, nil)

Expand Down Expand Up @@ -84,7 +84,6 @@ func newUserAPIServer(mSGetter manifestSetGetter, caGetter certChainGetter, reg
grpc: grpcServer,
policyTextStore: memstore.New[manifest.HexString, manifest.Policy](),
manifSetGetter: mSGetter,
caChainGetter: caGetter,
logger: log.WithGroup("userapi"),
metrics: userAPIMetrics{
manifestGeneration: manifestGeneration,
Expand Down Expand Up @@ -141,11 +140,12 @@ func (s *userAPIServer) SetManifest(ctx context.Context, req *userapi.SetManifes
return nil, status.Errorf(codes.Internal, "setting manifest: %v", err)
}

s.metrics.manifestGeneration.Set(float64(len(s.manifSetGetter.GetManifests())))
manifests, ca := s.manifSetGetter.GetManifestsAndLatestCA()
s.metrics.manifestGeneration.Set(float64(len(manifests)))

resp := &userapi.SetManifestResponse{
RootCA: s.caChainGetter.GetRootCACert(),
MeshCA: s.caChainGetter.GetMeshCACert(),
RootCA: ca.GetRootCACert(),
MeshCA: ca.GetMeshCACert(),
}

s.logger.Info("SetManifest succeeded")
Expand All @@ -158,7 +158,7 @@ func (s *userAPIServer) GetManifests(_ context.Context, _ *userapi.GetManifestsR
s.mux.RLock()
defer s.mux.RUnlock()

manifests := s.manifSetGetter.GetManifests()
manifests, ca := s.manifSetGetter.GetManifestsAndLatestCA()
if len(manifests) == 0 {
return nil, status.Errorf(codes.FailedPrecondition, "no manifests set")
}
Expand All @@ -176,8 +176,8 @@ func (s *userAPIServer) GetManifests(_ context.Context, _ *userapi.GetManifestsR
resp := &userapi.GetManifestsResponse{
Manifests: manifestBytes,
Policies: policySliceToBytesSlice(policies),
RootCA: s.caChainGetter.GetRootCACert(),
MeshCA: s.caChainGetter.GetMeshCACert(),
RootCA: ca.GetRootCACert(),
MeshCA: ca.GetMeshCACert(),
}

s.logger.Info("GetManifest succeeded")
Expand Down Expand Up @@ -252,15 +252,9 @@ func manifestSliceToBytesSlice(s []*manifest.Manifest) ([][]byte, error) {
return manifests, nil
}

type certChainGetter interface {
GetRootCACert() []byte
GetMeshCACert() []byte
GetIntermCACert() []byte
}

type manifestSetGetter interface {
SetManifest(*manifest.Manifest) error
GetManifests() []*manifest.Manifest
GetManifestsAndLatestCA() ([]*manifest.Manifest, *ca.CA)
LatestManifest() (*manifest.Manifest, error)
}

Expand Down
Loading

0 comments on commit 7f7deaf

Please sign in to comment.