Skip to content

Commit

Permalink
Enable mTLS when using aTLS
Browse files Browse the repository at this point in the history
  • Loading branch information
smithjilks committed Aug 1, 2024
1 parent e376cf3 commit 46624c6
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 66 deletions.
13 changes: 13 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ require (
golang.org/x/sync v0.7.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
github.com/stretchr/testify v1.8.4
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1
go.opentelemetry.io/otel v1.21.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.21.0
go.opentelemetry.io/otel/sdk v1.21.0
go.opentelemetry.io/otel/trace v1.21.0
golang.org/x/sync v0.6.0
google.golang.org/grpc v1.60.1
)

require (
Expand All @@ -33,6 +42,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/go-kit/log v0.2.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
Expand All @@ -54,6 +64,8 @@ require (
github.com/stretchr/objx v0.5.2 // indirect
go.opentelemetry.io/otel/metric v1.28.0 // indirect
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
github.com/stretchr/objx v0.5.1 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.22.0 // indirect
Expand All @@ -62,4 +74,5 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
golang.org/x/net v0.20.0 // indirect
)
112 changes: 71 additions & 41 deletions internal/server/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (s *Server) Start() error {
creds := grpc.Creds(insecure.NewCredentials())

switch {
case s.Config.AttestedTLS:
case s.Config.AttestedTLS && s.Config.CertFile != "" && s.Config.KeyFile != "":
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
Expand All @@ -103,60 +103,40 @@ func (s *Server) Start() error {
return fmt.Errorf("falied due to invalid key pair: %w", err)
}

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return err
}

tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" || s.Config.KeyFile != "":
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)

case s.Config.AttestedTLS:
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.agent)
if err != nil {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
return fmt.Errorf("failed to create certificate: %w", err)
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes)
if err != nil {
return fmt.Errorf("failed to load root ca file: %w", err)
return fmt.Errorf("falied due to invalid key pair: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" && s.Config.KeyFile != "":
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
return err
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
Expand Down Expand Up @@ -292,3 +272,53 @@ func generateCertificatesForATLS(svc agent.Service) ([]byte, []byte, error) {

return certBytes, keyBytes, nil
}

func (s *Server) setupTLSConfig() (*tls.Config, error) {
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return &tls.Config{}, fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return &tls.Config{}, fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}

return tlsConfig, nil
}
93 changes: 68 additions & 25 deletions pkg/clients/grpc/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ const (
withoutTLS security = iota
withTLS
withmTLS
withaTLS
withmaTLS
)

const (
Expand Down Expand Up @@ -122,6 +124,10 @@ func (c *client) Secure() string {
return "with TLS"
case withmTLS:
return "with mTLS"
case withmaTLS:
return "with maTLS"
case withaTLS:
return "with mTLS"
case withoutTLS:
fallthrough
default:
Expand All @@ -141,7 +147,8 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
secure := withoutTLS
tc := insecure.NewCredentials()

if cfg.AttestedTLS {
switch {
case cfg.AttestedTLS && cfg.ServerCAFile != "":
err := ReadManifest(cfg.Manifest, &attestationConfiguration)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
Expand All @@ -151,37 +158,73 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyAttestationReportTLS,
}
tc = credentials.NewTLS(tlsConfig)
} else {
if cfg.ServerCAFile != "" {
tlsConfig := &tls.Config{}

// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
tlsConfig.RootCAs = capool
secure = withaTLS
}

// Loading mTLS certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmaTLS
}

tc = credentials.NewTLS(tlsConfig)

case cfg.AttestedTLS:
err := ReadManifest(cfg.Manifest, &attestationConfiguration)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
}

tlsConfig := &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyAttestationReportTLS,
}
tc = credentials.NewTLS(tlsConfig)

case cfg.ServerCAFile != "":
tlsConfig := &tls.Config{}

// Loading mTLS certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
}

tc = credentials.NewTLS(tlsConfig)
// Loading mTLS certificates file
if cfg.ClientCert != "" && cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to load client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
}
tc = credentials.NewTLS(tlsConfig)
default:
}

opts = append(opts, grpc.WithTransportCredentials(tc))
Expand Down

0 comments on commit 46624c6

Please sign in to comment.