diff --git a/config.go b/config.go index 04a4917658..69b4a54241 100644 --- a/config.go +++ b/config.go @@ -51,6 +51,8 @@ const ( defaultChainSubDirname = "chain" defaultGraphSubDirname = "graph" defaultTowerSubDirname = "watchtower" + defaultCACertFilename = "ca.cert" + defaultCAKeyFilename = "ca.key" defaultTLSCertFilename = "tls.cert" defaultTLSKeyFilename = "tls.key" defaultAdminMacFilename = "admin.macaroon" @@ -264,6 +266,8 @@ var ( defaultTowerDir = filepath.Join(defaultDataDir, defaultTowerSubDirname) + defaultCACertPath = filepath.Join(DefaultLndDir, defaultCACertFilename) + defaultCAKeyPath = filepath.Join(DefaultLndDir, defaultCAKeyFilename) defaultTLSCertPath = filepath.Join(DefaultLndDir, defaultTLSCertFilename) defaultTLSKeyPath = filepath.Join(DefaultLndDir, defaultTLSKeyFilename) defaultLetsEncryptDir = filepath.Join(DefaultLndDir, defaultLetsEncryptDirname) @@ -299,6 +303,8 @@ type Config struct { DataDir string `short:"b" long:"datadir" description:"The directory to store lnd's data within"` SyncFreelist bool `long:"sync-freelist" description:"Whether the databases used within lnd should sync their freelist to disk. This is disabled by default resulting in improved memory performance during operation, but with an increase in startup time."` + CACertPath string `long:"cacertpath" description:"Path to write the CA certificate for lnd's RPC and REST services"` + CAKeyPath string `long:"cakeypath" description:"Path to write the CA private key for lnd's RPC and REST services"` TLSCertPath string `long:"tlscertpath" description:"Path to write the TLS certificate for lnd's RPC and REST services"` TLSKeyPath string `long:"tlskeypath" description:"Path to write the TLS private key for lnd's RPC and REST services"` TLSExtraIPs []string `long:"tlsextraip" description:"Adds an extra ip to the generated certificate"` @@ -556,6 +562,8 @@ func DefaultConfig() Config { ConfigFile: DefaultConfigFile, DataDir: defaultDataDir, DebugLevel: defaultLogLevel, + CACertPath: defaultCACertPath, + CAKeyPath: defaultCAKeyPath, TLSCertPath: defaultTLSCertPath, TLSKeyPath: defaultTLSKeyPath, TLSCertDuration: defaultTLSCertDuration, @@ -880,6 +888,8 @@ func ValidateConfig(cfg Config, interceptor signal.Interceptor, fileParser, cfg.LetsEncryptDir = filepath.Join( lndDir, defaultLetsEncryptDirname, ) + cfg.CACertPath = filepath.Join(lndDir, defaultCACertFilename) + cfg.CAKeyPath = filepath.Join(lndDir, defaultCAKeyFilename) cfg.TLSCertPath = filepath.Join(lndDir, defaultTLSCertFilename) cfg.TLSKeyPath = filepath.Join(lndDir, defaultTLSKeyFilename) cfg.LogDir = filepath.Join(lndDir, defaultLogDirname) @@ -962,6 +972,8 @@ func ValidateConfig(cfg Config, interceptor signal.Interceptor, fileParser, // to directories and files are cleaned and expanded before attempting // to use them later on. cfg.DataDir = CleanAndExpandPath(cfg.DataDir) + cfg.CACertPath = CleanAndExpandPath(cfg.CACertPath) + cfg.CAKeyPath = CleanAndExpandPath(cfg.CAKeyPath) cfg.TLSCertPath = CleanAndExpandPath(cfg.TLSCertPath) cfg.TLSKeyPath = CleanAndExpandPath(cfg.TLSKeyPath) cfg.LetsEncryptDir = CleanAndExpandPath(cfg.LetsEncryptDir) @@ -1382,6 +1394,7 @@ func ValidateConfig(cfg Config, interceptor signal.Interceptor, fileParser, dirs := []string{ lndDir, cfg.DataDir, cfg.networkDir, cfg.LetsEncryptDir, towerDir, cfg.graphDatabaseDir(), + filepath.Dir(cfg.CACertPath), filepath.Dir(cfg.CAKeyPath), filepath.Dir(cfg.TLSCertPath), filepath.Dir(cfg.TLSKeyPath), filepath.Dir(cfg.AdminMacPath), filepath.Dir(cfg.ReadMacPath), filepath.Dir(cfg.InvoiceMacPath), diff --git a/lnd.go b/lnd.go index f511811950..887388f738 100644 --- a/lnd.go +++ b/lnd.go @@ -265,6 +265,8 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, } tlsManagerCfg := &TLSManagerCfg{ + CACertPath: cfg.CACertPath, + CAKeyPath: cfg.CAKeyPath, TLSCertPath: cfg.TLSCertPath, TLSKeyPath: cfg.TLSKeyPath, TLSEncryptKey: cfg.TLSEncryptKey, diff --git a/tls_manager.go b/tls_manager.go index 076cf44bc8..dfc9e95237 100644 --- a/tls_manager.go +++ b/tls_manager.go @@ -43,6 +43,8 @@ var ( // TLSManagerCfg houses a set of values and methods that is passed to the // TLSManager for it to properly manage LND's TLS options. type TLSManagerCfg struct { + CACertPath string + CAKeyPath string TLSCertPath string TLSKeyPath string TLSEncryptKey bool @@ -176,7 +178,7 @@ func (t *TLSManager) getConfig() ([]grpc.ServerOption, []grpc.DialOption, func (t *TLSManager) generateOrRenewCert() (*tls.Config, error) { // Generete a TLS pair if we don't have one yet. var emptyKeyRing keychain.SecretKeyRing - err := t.generateCertPair(emptyKeyRing) + err := t.generateCerts(emptyKeyRing) if err != nil { return nil, err } @@ -188,9 +190,18 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, error) { return nil, err } + _, parsedCaCert, err := cert.LoadCert( + t.cfg.CACertPath, t.cfg.CAKeyPath, + ) + if err != nil { + rpcsLog.Warnf("Failed to load CA certficate. This could " + + "trigger certificate renewal.") + } + // Check to see if the certificate needs to be renewed. If it does, we // return the newly generated certificate data instead. - reloadedCertData, err := t.maintainCert(parsedCert) + reloadedCertData, err := t.maintainCerts(parsedCaCert, parsedCert, + emptyKeyRing) if err != nil { return nil, err } @@ -203,24 +214,37 @@ func (t *TLSManager) generateOrRenewCert() (*tls.Config, error) { return tlsCfg, nil } -// generateCertPair creates and writes a TLS pair to disk if the pair -// doesn't exist yet. If the TLSEncryptKey setting is on, and a plaintext key -// is already written to disk, this function overwrites the plaintext key with -// the encrypted form. -func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error { +// generateCerts creates and writes a CA pair and TLS pair to disk if the pairs +// don't exist yet. If the TLSEncryptKey setting is on, and a plaintext key is +// already written to disk, this function overwrites the plaintext key with the +// encrypted form. +func (t *TLSManager) generateCerts(keyRing keychain.SecretKeyRing) error { // Ensure we create TLS key and certificate if they don't exist. if lnrpc.FileExists(t.cfg.TLSCertPath) || lnrpc.FileExists(t.cfg.TLSKeyPath) { // Handle discrepencies related to the TLSEncryptKey setting. - return t.ensureEncryption(keyRing) + return t.ensureEncryption(keyRing, t.cfg.TLSCertPath, + t.cfg.TLSKeyPath) } rpcsLog.Infof("Generating TLS certificates...") + + // Always generate a new CA, regardless of whether it existed previously + // or not. + caBytes, caKeyBytes, err := cert.GenCertPair( + "lnd autogenerated ca cert", t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + t.cfg.TLSCertDuration, nil, nil, + ) + if err != nil { + return err + } + certBytes, keyBytes, err := cert.GenCertPair( "lnd autogenerated cert", t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, - t.cfg.TLSCertDuration, + t.cfg.TLSCertDuration, caBytes, caKeyBytes, ) if err != nil { return err @@ -234,6 +258,16 @@ func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error { "encrypt key %v", err) } + err = e.EncryptPayloadToWriter( + caKeyBytes, &b, + ) + if err != nil { + return err + } + + caKeyBytes = b.Bytes() + + b.Reset() err = e.EncryptPayloadToWriter( keyBytes, &b, ) @@ -244,6 +278,13 @@ func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error { keyBytes = b.Bytes() } + err = cert.WriteCertPair( + t.cfg.CACertPath, t.cfg.CAKeyPath, caBytes, caKeyBytes, + ) + if err != nil { + return err + } + err = cert.WriteCertPair( t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes, ) @@ -258,9 +299,11 @@ func (t *TLSManager) generateCertPair(keyRing keychain.SecretKeyRing) error { // encrypt the file and rewrite it to disk. // 2) On the flip side, if TLSEncryptKey is not set, but the key on disk // is encrypted, we need to error out and warn the user. -func (t *TLSManager) ensureEncryption(keyRing keychain.SecretKeyRing) error { +func (t *TLSManager) ensureEncryption(keyRing keychain.SecretKeyRing, + certPath string, keyPath string) error { + _, keyBytes, err := cert.GetCertBytesFromPath( - t.cfg.TLSCertPath, t.cfg.TLSKeyPath, + certPath, keyPath, ) if err != nil { return err @@ -279,7 +322,7 @@ func (t *TLSManager) ensureEncryption(keyRing keychain.SecretKeyRing) error { return err } err = os.WriteFile( - t.cfg.TLSKeyPath, b.Bytes(), modifyFilePermissions, + keyPath, b.Bytes(), modifyFilePermissions, ) if err != nil { return err @@ -323,11 +366,12 @@ func decryptTLSKeyBytes(keyRing keychain.SecretKeyRing, return plaintext, nil } -// maintainCert checks if the certificate IP and domains matches the config, +// maintainCerts checks if the certificate IP and domains matches the config, // and renews the certificate if either this data is outdated or the // certificate is expired. -func (t *TLSManager) maintainCert( - parsedCert *x509.Certificate) (*tls.Certificate, error) { +func (t *TLSManager) maintainCerts(parsedCaCert *x509.Certificate, + parsedCert *x509.Certificate, keyRing keychain.SecretKeyRing, +) (*tls.Certificate, error) { // We check whether the certificate we have on disk match the IPs and // domains specified by the config. If the extra IPs or domains have @@ -336,47 +380,42 @@ func (t *TLSManager) maintainCert( refresh := false var err error if t.cfg.TLSAutoRefresh { - refresh, err = cert.IsOutdated( - parsedCert, t.cfg.TLSExtraIPs, - t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, - ) + refresh, err = t.shouldRefresh(parsedCaCert, parsedCert) if err != nil { return nil, err } } - // If the certificate expired or it was outdated, delete it and the TLS - // key and generate a new pair. - if !time.Now().After(parsedCert.NotAfter) && !refresh { + if !refresh { return nil, nil } + // If the certificate expired or it was outdated, delete it and the TLS + // key and generate a new pair. ltndLog.Info("TLS certificate is expired or outdated, " + "generating a new one") - err = os.Remove(t.cfg.TLSCertPath) + err = os.Remove(t.cfg.CACertPath) if err != nil { return nil, err } - err = os.Remove(t.cfg.TLSKeyPath) + err = os.Remove(t.cfg.CAKeyPath) if err != nil { return nil, err } - rpcsLog.Infof("Renewing TLS certificates...") - certBytes, keyBytes, err := cert.GenCertPair( - "lnd autogenerated cert", t.cfg.TLSExtraIPs, - t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, - t.cfg.TLSCertDuration, - ) + err = os.Remove(t.cfg.TLSCertPath) if err != nil { return nil, err } - err = cert.WriteCertPair( - t.cfg.TLSCertPath, t.cfg.TLSKeyPath, certBytes, keyBytes, - ) + err = os.Remove(t.cfg.TLSKeyPath) + if err != nil { + return nil, err + } + + err = t.generateCerts(keyRing) if err != nil { return nil, err } @@ -391,6 +430,47 @@ func (t *TLSManager) maintainCert( return &reloadedCertData, err } +func (t *TLSManager) shouldRefresh(parsedCaCert *x509.Certificate, + parsedCert *x509.Certificate) (bool, error) { + + if parsedCaCert == nil { + return true, nil + } + + if parsedCert == nil { + return true, nil + } + + refresh, err := cert.IsOutdated( + parsedCaCert, t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + ) + if err != nil { + return false, err + } + if refresh { + return true, nil + } + + refresh, err = cert.IsOutdated( + parsedCert, t.cfg.TLSExtraIPs, + t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, + ) + if err != nil { + return false, err + } + if refresh { + return true, nil + } + + if time.Now().After(parsedCert.NotAfter) || + time.Now().After(parsedCaCert.NotAfter) { + return true, nil + } + + return false, nil +} + // setUpLetsEncrypt automatically generates a Let's Encrypt certificate if the // option is set. func (t *TLSManager) setUpLetsEncrypt(certData *tls.Certificate, @@ -507,6 +587,7 @@ func (t *TLSManager) loadEphemeralCertificate() ([]byte, error) { certBytes, keyBytes, err := cert.GenCertPair( "lnd ephemeral autogenerated cert", t.cfg.TLSExtraIPs, t.cfg.TLSExtraDomains, t.cfg.TLSDisableAutofill, tmpValidity, + nil, nil, ) if err != nil { return nil, err @@ -539,7 +620,7 @@ func (t *TLSManager) LoadPermanentCertificate( tmpCertPath) } - err = t.generateCertPair(keyRing) + err = t.generateCerts(keyRing) if err != nil { return err } diff --git a/tls_manager_test.go b/tls_manager_test.go index 42f010411b..88c6a5f32e 100644 --- a/tls_manager_test.go +++ b/tls_manager_test.go @@ -114,7 +114,7 @@ func TestTLSManagerGenCert(t *testing.T) { RootKey: privKey, } - err = tlsManager.generateCertPair(keyRing) + err = tlsManager.generateCerts(keyRing) require.NoError(t, err, "failed to generate new certificate") _, keyBytes, err = cert.GetCertBytesFromPath( @@ -160,7 +160,7 @@ func TestEnsureEncryption(t *testing.T) { // ensureEncryption should detect that the TLS key is in plaintext, // encrypt it, and rewrite the encrypted version to disk. - err = tlsManager.ensureEncryption(keyRing) + err = tlsManager.ensureEncryption(keyRing, certPath, keyPath) require.NoError(t, err, "failed to generate new certificate") // Grab the file from disk to check that the key is no longer @@ -175,7 +175,7 @@ func TestEnsureEncryption(t *testing.T) { // Now let's flip the cfg.TLSEncryptKey to false. Since the key on file // is encrypted, ensureEncryption should error out. tlsManager.cfg.TLSEncryptKey = false - err = tlsManager.ensureEncryption(keyRing) + err = tlsManager.ensureEncryption(keyRing, certPath, keyPath) require.Error(t, err) }