Skip to content

Commit

Permalink
Enable mTLS when using aTLS
Browse files Browse the repository at this point in the history
  • Loading branch information
smithjilks committed Jul 10, 2024
1 parent 4c4161c commit 4968ec5
Showing 1 changed file with 71 additions and 41 deletions.
112 changes: 71 additions & 41 deletions internal/server/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *Server) Start() error {
creds := grpc.Creds(insecure.NewCredentials())

switch {
case s.Config.AttestedTLS:
case s.Config.AttestedTLS && s.Config.CertFile != "" && s.Config.KeyFile != "":
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
Expand All @@ -101,60 +101,40 @@ func (s *Server) Start() error {
return fmt.Errorf("falied due to invalid key pair: %w", err)
}

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return err
}

tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" || s.Config.KeyFile != "":
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)

case s.Config.AttestedTLS:
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
return fmt.Errorf("failed to create certificate: %w", err)
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes)
if err != nil {
return fmt.Errorf("failed to load root ca file: %w", err)
return fmt.Errorf("falied due to invalid key pair: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" && s.Config.KeyFile != "":
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
return err
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
Expand Down Expand Up @@ -290,3 +270,53 @@ func generateCertificatesForATLS(svc agent.Service) ([]byte, []byte, error) {

return certBytes, keyBytes, nil
}

func (s *Server) setupTLSConfig() (*tls.Config, error) {
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return &tls.Config{}, fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return &tls.Config{}, fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

return tlsConfig, nil
}

0 comments on commit 4968ec5

Please sign in to comment.