Skip to content

Commit

Permalink
Enable mTLS when using aTLS
Browse files Browse the repository at this point in the history
  • Loading branch information
smithjilks committed Jul 11, 2024
1 parent 4c4161c commit c4b8571
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 66 deletions.
112 changes: 71 additions & 41 deletions internal/server/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *Server) Start() error {
creds := grpc.Creds(insecure.NewCredentials())

switch {
case s.Config.AttestedTLS:
case s.Config.AttestedTLS && s.Config.CertFile != "" && s.Config.KeyFile != "":
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
Expand All @@ -101,60 +101,40 @@ func (s *Server) Start() error {
return fmt.Errorf("falied due to invalid key pair: %w", err)
}

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return err
}

tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" || s.Config.KeyFile != "":
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)

case s.Config.AttestedTLS:
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
return fmt.Errorf("failed to create certificate: %w", err)
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes)
if err != nil {
return fmt.Errorf("failed to load root ca file: %w", err)
return fmt.Errorf("falied due to invalid key pair: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" && s.Config.KeyFile != "":
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
return err
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
Expand Down Expand Up @@ -290,3 +270,53 @@ func generateCertificatesForATLS(svc agent.Service) ([]byte, []byte, error) {

return certBytes, keyBytes, nil
}

func (s *Server) setupTLSConfig() (*tls.Config, error) {
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return &tls.Config{}, fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return &tls.Config{}, fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

return tlsConfig, nil
}
94 changes: 69 additions & 25 deletions pkg/clients/grpc/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const (
withoutTLS security = iota
withTLS
withmTLS
withaTLS
withmaTLS
)

var (
Expand Down Expand Up @@ -112,6 +114,10 @@ func (c *client) Secure() string {
return "with TLS"
case withmTLS:
return "with mTLS"
case withmaTLS:
return "with maTLS"
case withaTLS:
return "with mTLS"
case withoutTLS:
fallthrough
default:
Expand All @@ -131,7 +137,8 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
secure := withoutTLS
tc := insecure.NewCredentials()

if cfg.AttestedTLS {
switch {
case cfg.AttestedTLS && cfg.ServerCAFile != "":
err := readManifest(cfg)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
Expand All @@ -141,37 +148,74 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyAttestationReportTLS,
}
tc = credentials.NewTLS(tlsConfig)
} else {
if cfg.ServerCAFile != "" {
tlsConfig := &tls.Config{}

// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
tlsConfig.RootCAs = capool
secure = withaTLS
}

// Loading mTLS certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmaTLS
}

tc = credentials.NewTLS(tlsConfig)

case cfg.AttestedTLS:
err := readManifest(cfg)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
}

// Loading mTLS certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyAttestationReportTLS,
}
tc = credentials.NewTLS(tlsConfig)

case cfg.ServerCAFile != "":
tlsConfig := &tls.Config{}

// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
}

tc = credentials.NewTLS(tlsConfig)
// Loading mTLS certificates file
if cfg.ClientCert != "" && cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to load client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
}
tc = credentials.NewTLS(tlsConfig)
default:

}

opts = append(opts, grpc.WithTransportCredentials(tc))
Expand Down

0 comments on commit c4b8571

Please sign in to comment.