Skip to content

Commit

Permalink
meshapi: pass seedengine via getter
Browse files Browse the repository at this point in the history
  • Loading branch information
3u13r committed Aug 13, 2024
1 parent 2c3fc2a commit fee113c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 29 deletions.
9 changes: 9 additions & 0 deletions coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ func (m *Authority) walkTransitions(transitionRef [history.HashSize]byte, consum
return nil
}

// GetSeedEngine returns the seed engine.
func (m *Authority) GetSeedEngine() (*seedengine.SeedEngine, error) {
se := m.se.Load()
if se == nil {
return nil, errors.New("seed engine not initialized")
}
return se, nil
}

// State is a snapshot of the Coordinator's manifest history.
type State struct {
Manifest *manifest.Manifest
Expand Down
23 changes: 3 additions & 20 deletions coordinator/internal/authority/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"net"
"time"

"github.com/edgelesssys/contrast/coordinator/internal/seedengine"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/logger"
Expand All @@ -26,9 +25,8 @@ import (

// Credentials are gRPC transport credentials that dynamically update with the Coordinator state.
type Credentials struct {
issuer atls.Issuer
getState func() (*State, error)
getSeedEngine func() (*seedengine.SeedEngine, error)
issuer atls.Issuer
getState func() (*State, error)

logger *slog.Logger
attestationFailuresCounter prometheus.Counter
Expand Down Expand Up @@ -57,13 +55,6 @@ func (a *Authority) Credentials(reg *prometheus.Registry, issuer atls.Issuer) (*
}
return state, nil
},
getSeedEngine: func() (*seedengine.SeedEngine, error) {
se := a.se.Load()
if se == nil {
return nil, errors.New("seed engine not initialized")
}
return se, nil
},
logger: a.logger,
attestationFailuresCounter: attestationFailuresCounter,
kdsGetter: kdsGetter,
Expand All @@ -79,14 +70,8 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
return nil, nil, fmt.Errorf("getting state: %w", err)
}

seedEngine, err := c.getSeedEngine()
if err != nil {
return nil, nil, fmt.Errorf("getting seed engine: %w", err)
}

authInfo := AuthInfo{
State: state,
SeedEngine: seedEngine,
State: state,
}

validator := snp.NewValidatorWithCallbacks(state, c.kdsGetter,
Expand Down Expand Up @@ -142,8 +127,6 @@ type AuthInfo struct {
State *State
// Report is the attestation report sent by the peer.
Report *sevsnp.Report
// SeedEngine is the seed engine used to derive secrets.
SeedEngine *seedengine.SeedEngine
}

// ValidateCallback takes the validated report and attaches it to the [AuthInfo].
Expand Down
7 changes: 6 additions & 1 deletion coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/edgelesssys/contrast/coordinator/history"
"github.com/edgelesssys/contrast/coordinator/internal/authority"
"github.com/edgelesssys/contrast/coordinator/internal/seedengine"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
Expand Down Expand Up @@ -76,7 +77,11 @@ func run() (retErr error) {

userapi.RegisterUserAPIServer(grpcServer, meshAuth)
serverMetrics.InitializeMetrics(grpcServer)
meshAPI := newMeshAPIServer(meshAuth, promRegistry, serverMetrics, logger)
meshAPI := newMeshAPIServer(meshAuth, promRegistry, serverMetrics,
func() (*seedengine.SeedEngine, error) {
return meshAuth.GetSeedEngine()
},
logger)
metricsServer := &http.Server{}

eg, ctx := errgroup.WithContext(ctx)
Expand Down
25 changes: 17 additions & 8 deletions coordinator/meshapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/edgelesssys/contrast/coordinator/internal/authority"
"github.com/edgelesssys/contrast/coordinator/internal/seedengine"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/manifest"
Expand All @@ -24,14 +25,17 @@ import (
)

type meshAPIServer struct {
grpc *grpc.Server
cleanup func()
logger *slog.Logger
grpc *grpc.Server
cleanup func()
getSeedEngine func() (*seedengine.SeedEngine, error)
logger *slog.Logger

meshapi.UnimplementedMeshAPIServer
}

func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, serverMetrics *grpcprometheus.ServerMetrics, log *slog.Logger) *meshAPIServer {
func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, serverMetrics *grpcprometheus.ServerMetrics,
getSeedEngine func() (*seedengine.SeedEngine, error), log *slog.Logger,
) *meshAPIServer {
credentials, cancel := meshAuth.Credentials(reg, atls.NoIssuer)

grpcServer := grpc.NewServer(
Expand All @@ -45,9 +49,10 @@ func newMeshAPIServer(meshAuth *authority.Authority, reg *prometheus.Registry, s
),
)
s := &meshAPIServer{
grpc: grpcServer,
cleanup: cancel,
logger: log.WithGroup("meshapi"),
grpc: grpcServer,
cleanup: cancel,
getSeedEngine: getSeedEngine,
logger: log.WithGroup("meshapi"),
}
meshapi.RegisterMeshAPIServer(s.grpc, s)
serverMetrics.InitializeMetrics(s.grpc)
Expand Down Expand Up @@ -85,7 +90,11 @@ func (i *meshAPIServer) NewMeshCert(ctx context.Context, _ *meshapi.NewMeshCertR
state := authInfo.State
report := authInfo.Report
tlsInfo := authInfo.TLSInfo
seedEngine := authInfo.SeedEngine

seedEngine, err := i.getSeedEngine()
if err != nil {
return nil, fmt.Errorf("failed to get seed engine: %w", err)
}

if len(tlsInfo.State.PeerCertificates) == 0 {
return nil, fmt.Errorf("no peer certificates found")
Expand Down

0 comments on commit fee113c

Please sign in to comment.