diff --git a/agent/config/config.go b/agent/config/config.go index 8c216a97..f382c12b 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -128,6 +128,13 @@ func (c *ListenerConfig) Validate() error { } type TLSConfig struct { + // Cert contains a path to the PEM encoded certificate to present to + // the server (optional). + Cert string `json:"cert" yaml:"cert"` + + // Key contains a path to the PEM encoded private key (optional). + Key string `json:"key" yaml:"key"` + // RootCAs contains a path to root certificate authorities to validate // the TLS connection to the Piko server. // @@ -142,12 +149,31 @@ type TLSConfig struct { } func (c *TLSConfig) Validate() error { + if c.Cert != "" && c.Key == "" { + return fmt.Errorf("missing key") + } + _, err := c.Load() return err } func (c *TLSConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { prefix = prefix + ".tls." + fs.StringVar( + &c.Cert, + prefix+"cert", + c.Cert, + ` +Path to the PEM encoded certificate file to present to the server.`, + ) + fs.StringVar( + &c.Key, + prefix+"key", + c.Key, + ` +Path to the PEM encoded key file.`, + ) + fs.StringVar( &c.RootCAs, prefix+"root-cas", @@ -172,6 +198,14 @@ host name in that certificate.`, func (c *TLSConfig) Load() (*tls.Config, error) { tlsConfig := &tls.Config{} + if c.Cert != "" { + cert, err := tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + return nil, fmt.Errorf("load key pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + if c.RootCAs != "" { caCert, err := os.ReadFile(c.RootCAs) if err != nil { diff --git a/forward/config/config.go b/forward/config/config.go index ab26daa5..a6a69936 100644 --- a/forward/config/config.go +++ b/forward/config/config.go @@ -57,6 +57,13 @@ func (c *PortConfig) Validate() error { } type TLSConfig struct { + // Cert contains a path to the PEM encoded certificate to present to + // the server (optional). + Cert string `json:"cert" yaml:"cert"` + + // Key contains a path to the PEM encoded private key (optional). + Key string `json:"key" yaml:"key"` + // RootCAs contains a path to root certificate authorities to validate // the TLS connection to the Piko server. // @@ -71,12 +78,31 @@ type TLSConfig struct { } func (c *TLSConfig) Validate() error { + if c.Cert != "" && c.Key == "" { + return fmt.Errorf("missing key") + } + _, err := c.Load() return err } func (c *TLSConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { prefix = prefix + ".tls." + fs.StringVar( + &c.Cert, + prefix+"cert", + c.Cert, + ` +Path to the PEM encoded certificate file to present to the server.`, + ) + fs.StringVar( + &c.Key, + prefix+"key", + c.Key, + ` +Path to the PEM encoded key file.`, + ) + fs.StringVar( &c.RootCAs, prefix+"root-cas", @@ -101,6 +127,14 @@ host name in that certificate.`, func (c *TLSConfig) Load() (*tls.Config, error) { tlsConfig := &tls.Config{} + if c.Cert != "" { + cert, err := tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + return nil, fmt.Errorf("load key pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + if c.RootCAs != "" { caCert, err := os.ReadFile(c.RootCAs) if err != nil { diff --git a/server/config/tls.go b/server/config/tls.go index 14196ba6..3b1832b0 100644 --- a/server/config/tls.go +++ b/server/config/tls.go @@ -2,14 +2,17 @@ package config import ( "crypto/tls" + "crypto/x509" "fmt" + "os" "github.com/spf13/pflag" ) type TLSConfig struct { - Cert string `json:"cert" yaml:"cert"` - Key string `json:"key" yaml:"key"` + Cert string `json:"cert" yaml:"cert"` + Key string `json:"key" yaml:"key"` + ClientCAs string `json:"client_cas" yaml:"client_cas"` } func (c *TLSConfig) Validate() error { @@ -45,6 +48,16 @@ If given the server will listen on TLS`, ` Path to the PEM encoded key file.`, ) + fs.StringVar( + &c.ClientCAs, + prefix+"client-cas", + c.ClientCAs, + ` +A path to a certificate PEM file containing client certificiate authorities to +verify the client certificates. + +When set the client must set a valid certificate during the TLS handshake.`, + ) } func (c *TLSConfig) Load() (*tls.Config, error) { @@ -59,6 +72,20 @@ func (c *TLSConfig) Load() (*tls.Config, error) { } tlsConfig.Certificates = []tls.Certificate{cert} + if c.ClientCAs != "" { + caCert, err := os.ReadFile(c.ClientCAs) + if err != nil { + return nil, fmt.Errorf("open client cas: %s: %w", c.ClientCAs, err) + } + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(caCert) + if !ok { + return nil, fmt.Errorf("parse client cas: %s: %w", c.ClientCAs, err) + } + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + return tlsConfig, nil }