From 383c8a8eb7a184b2a1bf3fa456801a53dda3c1c7 Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Tue, 13 Aug 2024 15:20:30 +0200 Subject: [PATCH] meshapi: pass seedengine via getter --- coordinator/internal/authority/authority.go | 9 +++++++ coordinator/internal/authority/credentials.go | 23 +++-------------- coordinator/main.go | 7 +++++- coordinator/meshapi.go | 25 +++++++++++++------ 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/coordinator/internal/authority/authority.go b/coordinator/internal/authority/authority.go index 82cb45b033..21b8c2cc94 100644 --- a/coordinator/internal/authority/authority.go +++ b/coordinator/internal/authority/authority.go @@ -163,6 +163,15 @@ func (m *Authority) walkTransitions(transitionRef [history.HashSize]byte, consum return nil } +// GetSeedEngine returns the seed engine. +func (m *Authority) GetSeedEngine() (*seedengine.SeedEngine, error) { + se := m.se.Load() + if se == nil { + return nil, errors.New("seed engine not initialized") + } + return se, nil +} + // State is a snapshot of the Coordinator's manifest history. type State struct { Manifest *manifest.Manifest diff --git a/coordinator/internal/authority/credentials.go b/coordinator/internal/authority/credentials.go index baae6ce576..6bd1844555 100644 --- a/coordinator/internal/authority/credentials.go +++ b/coordinator/internal/authority/credentials.go @@ -12,7 +12,6 @@ import ( "net" "time" - "github.com/edgelesssys/contrast/coordinator/internal/seedengine" "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/attestation/snp" "github.com/edgelesssys/contrast/internal/logger" @@ -26,9 +25,8 @@ import ( // Credentials are gRPC transport credentials that dynamically update with the Coordinator state. type Credentials struct { - issuer atls.Issuer - getState func() (*State, error) - getSeedEngine func() (*seedengine.SeedEngine, error) + issuer atls.Issuer + getState func() (*State, error) logger *slog.Logger attestationFailuresCounter prometheus.Counter @@ -57,13 +55,6 @@ func (a *Authority) Credentials(reg *prometheus.Registry, issuer atls.Issuer) (* } return state, nil }, - getSeedEngine: func() (*seedengine.SeedEngine, error) { - se := a.se.Load() - if se == nil { - return nil, errors.New("seed engine not initialized") - } - return se, nil - }, logger: a.logger, attestationFailuresCounter: attestationFailuresCounter, kdsGetter: kdsGetter, @@ -79,14 +70,8 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A return nil, nil, fmt.Errorf("getting state: %w", err) } - seedEngine, err := c.getSeedEngine() - if err != nil { - return nil, nil, fmt.Errorf("getting seed engine: %w", err) - } - authInfo := AuthInfo{ - State: state, - SeedEngine: seedEngine, + State: state, } validator := snp.NewValidatorWithCallbacks(state, c.kdsGetter, @@ -142,8 +127,6 @@ type AuthInfo struct { State *State // Report is the attestation report sent by the peer. Report *sevsnp.Report - // SeedEngine is the seed engine used to derive secrets. - SeedEngine *seedengine.SeedEngine } // ValidateCallback takes the validated report and attaches it to the [AuthInfo]. diff --git a/coordinator/main.go b/coordinator/main.go index 306b5667b1..20bd4ca218 100644 --- a/coordinator/main.go +++ b/coordinator/main.go @@ -15,6 +15,7 @@ import ( "github.com/edgelesssys/contrast/coordinator/history" "github.com/edgelesssys/contrast/coordinator/internal/authority" + "github.com/edgelesssys/contrast/coordinator/internal/seedengine" "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/grpc/atlscredentials" "github.com/edgelesssys/contrast/internal/logger" @@ -76,7 +77,11 @@ func run() (retErr error) { userapi.RegisterUserAPIServer(grpcServer, meshAuth) serverMetrics.InitializeMetrics(grpcServer) - meshAPI := newMeshAPIServer(meshAuth, promRegistry, serverMetrics, logger) + meshAPI := newMeshAPIServer(meshAuth, promRegistry, serverMetrics, + func() (*seedengine.SeedEngine, error) { + return meshAuth.GetSeedEngine() + }, + logger) metricsServer := &http.Server{} eg, ctx := errgroup.WithContext(ctx) diff --git a/coordinator/meshapi.go b/coordinator/meshapi.go index dd24ceb574..974634a836 100644 --- a/coordinator/meshapi.go +++ b/coordinator/meshapi.go @@ -12,6 +12,7 @@ import ( "time" "github.com/edgelesssys/contrast/coordinator/internal/authority" + "github.com/edgelesssys/contrast/coordinator/internal/seedengine" "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/attestation/snp" "github.com/edgelesssys/contrast/internal/manifest" @@ -24,14 +25,17 @@ import ( ) type meshAPIServer struct { - grpc *grpc.Server - cleanup func() - logger *slog.Logger + grpc *grpc.Server + cleanup func() + getSeedEngine func() (*seedengine.SeedEngine, error) + logger *slog.Logger meshapi.UnimplementedMeshAPIServer } -func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, serverMetrics *grpcprometheus.ServerMetrics, log *slog.Logger) *meshAPIServer { +func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, serverMetrics *grpcprometheus.ServerMetrics, + getSeedEngine func() (*seedengine.SeedEngine, error), log *slog.Logger, +) *meshAPIServer { credentials, cancel := meshAuth.Credentials(reg, atls.NoIssuer) grpcServer := grpc.NewServer( @@ -45,9 +49,10 @@ func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, s ), ) s := &meshAPIServer{ - grpc: grpcServer, - cleanup: cancel, - logger: log.WithGroup("meshapi"), + grpc: grpcServer, + cleanup: cancel, + getSeedEngine: getSeedEngine, + logger: log.WithGroup("meshapi"), } meshapi.RegisterMeshAPIServer(s.grpc, s) serverMetrics.InitializeMetrics(s.grpc) @@ -85,7 +90,11 @@ func (i *meshAPIServer) NewMeshCert(ctx context.Context, _ *meshapi.NewMeshCertR state := authInfo.State report := authInfo.Report tlsInfo := authInfo.TLSInfo - seedEngine := authInfo.SeedEngine + + seedEngine, err := i.getSeedEngine() + if err != nil { + return nil, fmt.Errorf("failed to get seed engine: %w", err) + } if len(tlsInfo.State.PeerCertificates) == 0 { return nil, fmt.Errorf("no peer certificates found")