From aa210f80e29a120a662b7e06665fa1e1fbc962ca Mon Sep 17 00:00:00 2001 From: Ruben Vargas Date: Wed, 17 Jul 2024 03:24:18 -0600 Subject: [PATCH] Add watcher for upstreams TLS certificates (#716) * Add watcher to TLS certificates Signed-off-by: Ruben Vargas * Fix CA loading logic Signed-off-by: Ruben Vargas * Add some comments, change parameters order Signed-off-by: Ruben Vargas * Fix some comments and clarify some names Signed-off-by: Ruben Vargas --------- Signed-off-by: Ruben Vargas --- README.md | 6 + api/logs/v1/http.go | 11 +- api/metrics/legacy/http.go | 5 +- api/metrics/v1/http.go | 11 +- api/traces/v1/api.go | 17 +-- api/traces/v1/http.go | 9 +- go.mod | 2 +- main.go | 151 +++++++++++---------- tls/ca_watcher.go | 112 ++++++++++++++++ tls/ca_watcher_test.go | 123 +++++++++++++++++ tls/config.go | 19 --- tls/options.go | 159 ++++++++++++++++++++++ tls/options_test.go | 268 +++++++++++++++++++++++++++++++++++++ 13 files changed, 772 insertions(+), 121 deletions(-) create mode 100644 tls/ca_watcher.go create mode 100644 tls/ca_watcher_test.go create mode 100644 tls/options.go create mode 100644 tls/options_test.go diff --git a/README.md b/README.md index b95fa4e71..62cd466ba 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,8 @@ Usage of ./observatorium-api: File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS. -logs.tls.key-file string File containing the TLS client key to authenticate against upstream logs servers. Leave blank to disable mTLS. + -logs.tls.watch-certs + Watch for certificate changes and reload -logs.write-timeout duration The HTTP write timeout for proxied requests to the logs endpoint. (default 10m0s) -logs.write.endpoint string @@ -133,6 +135,8 @@ Usage of ./observatorium-api: File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS. -metrics.tls.key-file string File containing the TLS client key to authenticate against upstream metrics servers. Leave blank to disable mTLS. + -metrics.tls.watch-certs + Watch for certificate changes and reload -metrics.write-timeout duration The HTTP write timeout for proxied requests to the metrics endpoint. (default 2m0s) -metrics.write.endpoint string @@ -193,6 +197,8 @@ Usage of ./observatorium-api: File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS. -traces.tls.key-file string File containing the TLS client key to authenticate against upstream traces servers. Leave blank to disable mTLS. + -traces.tls.watch-certs + Watch for certificate changes and reload -traces.write-timeout duration The HTTP write timeout for proxied requests to the traces endpoint. (default 2m0s) -traces.write.otlpgrpc.endpoint string diff --git a/api/logs/v1/http.go b/api/logs/v1/http.go index 21f38f63e..e853103c9 100644 --- a/api/logs/v1/http.go +++ b/api/logs/v1/http.go @@ -2,7 +2,6 @@ package http import ( - stdtls "crypto/tls" "net" "net/http" "net/http/httputil" @@ -147,7 +146,7 @@ func (n nopInstrumentHandler) NewHandler(labels prometheus.Labels, handler http. return handler.ServeHTTP } -func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, upstreamCA []byte, upstreamCert *stdtls.Certificate, opts ...HandlerOption) http.Handler { +func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, tlsOptions *tls.UpstreamOptions, opts ...HandlerOption) http.Handler { c := &handlerConfiguration{ logger: log.NewNopLogger(), registry: prometheus.NewRegistry(), @@ -174,7 +173,7 @@ func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, upstreamC DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyRead = &httputil.ReverseProxy{ @@ -250,7 +249,7 @@ func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, upstreamC DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyReadRules = &httputil.ReverseProxy{ @@ -350,7 +349,7 @@ func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, upstreamC DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } tailRead = &httputil.ReverseProxy{ @@ -386,7 +385,7 @@ func NewHandler(read, tail, write, rules *url.URL, rulesReadOnly bool, upstreamC DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyWrite = &httputil.ReverseProxy{ diff --git a/api/metrics/legacy/http.go b/api/metrics/legacy/http.go index 0e83192ef..ddc150ad3 100644 --- a/api/metrics/legacy/http.go +++ b/api/metrics/legacy/http.go @@ -1,7 +1,6 @@ package legacy import ( - stdtls "crypto/tls" "net" "net/http" "net/http/httputil" @@ -102,7 +101,7 @@ func (n nopInstrumentHandler) NewHandler(_ prometheus.Labels, handler http.Handl return handler.ServeHTTP } -func NewHandler(url *url.URL, upstreamCA []byte, upstreamCert *stdtls.Certificate, opts ...HandlerOption) http.Handler { +func NewHandler(url *url.URL, tlsOptions *tls.UpstreamOptions, opts ...HandlerOption) http.Handler { c := &handlerConfiguration{ logger: log.NewNopLogger(), registry: prometheus.NewRegistry(), @@ -130,7 +129,7 @@ func NewHandler(url *url.URL, upstreamCA []byte, upstreamCert *stdtls.Certificat DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } legacyProxy = &httputil.ReverseProxy{ diff --git a/api/metrics/v1/http.go b/api/metrics/v1/http.go index d85433b6e..243a0fe5e 100644 --- a/api/metrics/v1/http.go +++ b/api/metrics/v1/http.go @@ -1,7 +1,6 @@ package v1 import ( - stdtls "crypto/tls" "net" "net/http" "net/http/httputil" @@ -174,7 +173,7 @@ type Endpoints struct { // NewHandler creates the new metrics v1 handler. // nolint:funlen -func NewHandler(endpoints Endpoints, upstreamCA []byte, upstreamCert *stdtls.Certificate, opts ...HandlerOption) http.Handler { +func NewHandler(endpoints Endpoints, tlsOptions *tls.UpstreamOptions, opts ...HandlerOption) http.Handler { c := &handlerConfiguration{ logger: log.NewNopLogger(), registry: prometheus.NewRegistry(), @@ -258,7 +257,7 @@ func NewHandler(endpoints Endpoints, upstreamCA []byte, upstreamCert *stdtls.Cer DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyRead = &httputil.ReverseProxy{ @@ -345,7 +344,7 @@ func NewHandler(endpoints Endpoints, upstreamCA []byte, upstreamCert *stdtls.Cer ) t := http.DefaultTransport.(*http.Transport) - t.TLSClientConfig = tls.NewClientConfig(upstreamCA, upstreamCert) + t.TLSClientConfig = tlsOptions.NewClientConfig() uiProxy = &httputil.ReverseProxy{ Director: middlewares, @@ -384,7 +383,7 @@ func NewHandler(endpoints Endpoints, upstreamCA []byte, upstreamCert *stdtls.Cer DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyWrite = &httputil.ReverseProxy{ @@ -469,7 +468,7 @@ func NewHandler(endpoints Endpoints, upstreamCA []byte, upstreamCert *stdtls.Cer DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyAlertmanager = &httputil.ReverseProxy{ diff --git a/api/traces/v1/api.go b/api/traces/v1/api.go index 9cb7f5a74..e30789695 100644 --- a/api/traces/v1/api.go +++ b/api/traces/v1/api.go @@ -2,7 +2,6 @@ package v1 import ( "context" - stdtls "crypto/tls" "time" "github.com/go-kit/log" @@ -18,9 +17,8 @@ import ( const TraceRoute = "/opentelemetry.proto.collector.trace.v1.TraceService/Export" type connOptions struct { - logger log.Logger - tracesUpstreamCert *stdtls.Certificate - tracesUpstreamCA []byte + logger log.Logger + tlsOptions *tls.UpstreamOptions } // ClientOption modifies the connection's configuration. @@ -33,15 +31,14 @@ func WithLogger(logger log.Logger) ClientOption { } } -func WithUpstreamTLS(tracesUpstreamCA []byte, tracesUpstreamCert *stdtls.Certificate) ClientOption { +func WithUpstreamTLSOptions(tlsOptions *tls.UpstreamOptions) ClientOption { return func(h *connOptions) { - h.tracesUpstreamCA = tracesUpstreamCA - h.tracesUpstreamCert = tracesUpstreamCert + h.tlsOptions = tlsOptions } } -func newCredentials(upstreamCA []byte, upstreamCert *stdtls.Certificate) credentials.TransportCredentials { - tlsConfig := tls.NewClientConfig(upstreamCA, upstreamCert) +func newCredentials(tlsOptions *tls.UpstreamOptions) credentials.TransportCredentials { + tlsConfig := tlsOptions.NewClientConfig() if tlsConfig == nil { return insecure.NewCredentials() } @@ -70,5 +67,5 @@ func NewOTelConnection(write string, opts ...ClientOption) (*grpc.ClientConn, er // because the codec we need to register is also deprecated. A better fix, is the newer // version of mwitkow/grpc-proxy, but that version doesn't (currently) work with OTel protocol. grpc.WithCodec(grpcproxy.Codec()), // nolint: staticcheck - grpc.WithTransportCredentials(newCredentials(c.tracesUpstreamCA, c.tracesUpstreamCert))) + grpc.WithTransportCredentials(newCredentials(c.tlsOptions))) } diff --git a/api/traces/v1/http.go b/api/traces/v1/http.go index 82f22753f..7055b5130 100644 --- a/api/traces/v1/http.go +++ b/api/traces/v1/http.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/flate" "compress/gzip" - stdtls "crypto/tls" "fmt" "io" "net" @@ -109,7 +108,7 @@ func (n nopInstrumentHandler) NewHandler(labels prometheus.Labels, handler http. // The web UI handler is able to rewrite // HTML to change the attribute so that it works with the Observatorium-style // "/api/v1/traces/{tenant}/" URLs. -func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPHttp *url.URL, upstreamCA []byte, upstreamCert *stdtls.Certificate, opts ...HandlerOption) http.Handler { +func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPHttp *url.URL, tlsOptions *tls.UpstreamOptions, opts ...HandlerOption) http.Handler { if read == nil && readTemplate == "" && tempo == nil { panic("missing Jaeger read url") @@ -152,7 +151,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyRead = &httputil.ReverseProxy{ @@ -203,7 +202,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } proxyOTLP := &httputil.ReverseProxy{ @@ -229,7 +228,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH DialContext: (&net.Dialer{ Timeout: dialTimeout, }).DialContext, - TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert), + TLSClientConfig: tlsOptions.NewClientConfig(), } middlewares := proxy.Middlewares( diff --git a/go.mod b/go.mod index 99c8f469f..3006513a7 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/prometheus/common v0.53.0 github.com/prometheus/prometheus v0.50.1 github.com/redis/rueidis v1.0.37 + github.com/stretchr/testify v1.9.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 go.opentelemetry.io/contrib/propagators/jaeger v1.28.0 go.opentelemetry.io/otel v1.28.0 @@ -159,7 +160,6 @@ require ( github.com/schollz/closestmatch v2.1.0+incompatible // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.9.0 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/tdewolff/minify/v2 v2.12.9 // indirect github.com/tdewolff/parse/v2 v2.6.8 // indirect diff --git a/main.go b/main.go index 9ceaeac19..3274b09e7 100644 --- a/main.go +++ b/main.go @@ -152,7 +152,8 @@ type metricsConfig struct { tenantHeader string tenantLabel string // enable metrics if at least one {read|write}Endpoint} is provided. - enabled bool + enabled bool + enableCertWatcher bool } type logsConfig struct { @@ -171,7 +172,8 @@ type logsConfig struct { rulesLabelFilters map[string][]string authExtractSelectors []string // enable logs at least one {read,write,tail}Endpoint} is provided. - enabled bool + enabled bool + enableCertWatcher bool } type tracesConfig struct { @@ -188,7 +190,8 @@ type tracesConfig struct { upstreamKeyFile string tenantHeader string // enable traces if readTemplateEndpoint, readEndpoint, or writeEndpoint is provided. - enabled bool + enabled bool + enableCertWatcher bool } type middlewareConfig struct { @@ -500,62 +503,6 @@ func main() { ) } - var ( - metricsUpstreamCACert []byte - metricsUpstreamClientCert *stdtls.Certificate - logsUpstreamCACert []byte - logsUpstreamClientCert *stdtls.Certificate - tracesUpstreamCACert []byte - tracesUpstreamClientCert *stdtls.Certificate - ) - - if cfg.metrics.upstreamCAFile != "" { - metricsUpstreamCACert, err = os.ReadFile(cfg.metrics.upstreamCAFile) - if err != nil { - stdlog.Fatalf("failed to read upstream metrics TLS CA: %v", err) - } - } - - if cfg.metrics.upstreamCertFile != "" && cfg.metrics.upstreamKeyFile != "" { - clientCert, err := stdtls.LoadX509KeyPair(cfg.metrics.upstreamCertFile, cfg.metrics.upstreamKeyFile) - if err != nil { - stdlog.Fatalf("failed to read upstream metrics client TLS cert/key pair: %v", err) - } - metricsUpstreamClientCert = &clientCert - } - - if cfg.logs.upstreamCAFile != "" { - logsUpstreamCACert, err = os.ReadFile(cfg.logs.upstreamCAFile) - if err != nil { - stdlog.Fatalf("failed to read upstream logs TLS CA: %v", err) - } - - } - - if cfg.logs.upstreamCertFile != "" && cfg.logs.upstreamKeyFile != "" { - clientCert, err := stdtls.LoadX509KeyPair(cfg.logs.upstreamCertFile, cfg.logs.upstreamKeyFile) - if err != nil { - stdlog.Fatalf("failed to read upstream logs client TLS cert/key pair: %v", err) - } - logsUpstreamClientCert = &clientCert - } - - if cfg.traces.upstreamCAFile != "" { - tracesUpstreamCACert, err = os.ReadFile(cfg.traces.upstreamCAFile) - if err != nil { - stdlog.Fatalf("failed to read upstream traces TLS CA: %v", err) - } - - } - - if cfg.traces.upstreamCertFile != "" && cfg.traces.upstreamKeyFile != "" { - clientCert, err := stdtls.LoadX509KeyPair(cfg.traces.upstreamCertFile, cfg.traces.upstreamKeyFile) - if err != nil { - stdlog.Fatalf("failed to read upstream traces client TLS cert/key pair: %v", err) - } - tracesUpstreamClientCert = &clientCert - } - r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) @@ -590,6 +537,7 @@ func main() { // registration failures per tenant. registerTenantsFailingMetric = authentication.RegisterTenantsFailingMetric(reg) pm = authentication.NewProviderManager(logger, registerTenantsFailingMetric) + tracesUpstreamTLSOptions *tls.UpstreamOptions ) r.Group(func(r chi.Router) { @@ -658,6 +606,26 @@ func main() { // Metrics. if cfg.metrics.enabled { + + var loadInterval *time.Duration + + if cfg.metrics.enableCertWatcher { + loadInterval = &cfg.tls.reloadInterval + } + + metricsUpstreamClientOptions, err := tls.NewUpstreamOptions( + context.Background(), + cfg.metrics.upstreamCertFile, + cfg.metrics.upstreamKeyFile, + cfg.metrics.upstreamCAFile, + loadInterval, + logger, + g) + + if err != nil { + stdlog.Fatalf("failed to read upstream logs TLS: %v", err) + } + eps := metricsv1.Endpoints{ ReadEndpoint: cfg.metrics.readEndpoint, WriteEndpoint: cfg.metrics.writeEndpoint, @@ -693,8 +661,7 @@ func main() { const queryParamName = "query" r.Mount("/api/v1/{tenant}", metricslegacy.NewHandler( cfg.metrics.readEndpoint, - metricsUpstreamCACert, - metricsUpstreamClientCert, + metricsUpstreamClientOptions, metricslegacy.WithLogger(logger), metricslegacy.WithRegistry(reg), metricslegacy.WithHandlerInstrumenter(instrumenter), @@ -708,8 +675,7 @@ func main() { const matchParamName = "match[]" r.Mount("/api/metrics/v1/{tenant}", metricsv1.NewHandler( eps, - metricsUpstreamCACert, - metricsUpstreamClientCert, + metricsUpstreamClientOptions, metricsv1.WithLogger(logger), metricsv1.WithRegistry(reg), metricsv1.WithHandlerInstrumenter(instrumenter), @@ -743,6 +709,26 @@ func main() { // Logs. if cfg.logs.enabled { + + var loadInterval *time.Duration + + if cfg.logs.enableCertWatcher { + loadInterval = &cfg.tls.reloadInterval + } + + logsUpstreamClientOptions, err := tls.NewUpstreamOptions( + context.Background(), + cfg.logs.upstreamCertFile, + cfg.logs.upstreamKeyFile, + cfg.logs.upstreamCAFile, + loadInterval, + logger, + g) + + if err != nil { + stdlog.Fatalf("failed to read upstream logs TLS: %v", err) + } + r.Group(func(r chi.Router) { r.Use(middleware.Timeout(cfg.logs.upstreamWriteTimeout)) r.Mount("/api/logs/v1/{tenant}", @@ -753,8 +739,7 @@ func main() { cfg.logs.writeEndpoint, cfg.logs.rulesEndpoint, cfg.logs.rulesReadOnly, - logsUpstreamCACert, - logsUpstreamClientCert, + logsUpstreamClientOptions, logsv1.Logger(logger), logsv1.WithRegistry(reg), logsv1.WithHandlerInstrumenter(instrumenter), @@ -779,6 +764,23 @@ func main() { // Traces. if cfg.traces.enabled && (cfg.traces.readEndpoint != nil || cfg.traces.readTemplateEndpoint != "" || cfg.traces.tempoEndpoint != nil) { + var loadInterval *time.Duration + if cfg.traces.enableCertWatcher { + loadInterval = &cfg.tls.reloadInterval + } + tracesUpstreamTLSOptions, err = tls.NewUpstreamOptions( + context.Background(), + cfg.traces.upstreamCertFile, + cfg.traces.upstreamKeyFile, + cfg.traces.upstreamCAFile, + loadInterval, + logger, + g) + + if err != nil { + stdlog.Fatalf("failed to read upstream traces TLS: %v", err) + } + r.Group(func(r chi.Router) { r.Use(authentication.WithTenantMiddlewares(pm.Middlewares)) r.Use(authentication.WithTenantHeader(cfg.traces.tenantHeader, tenantIDs)) @@ -804,8 +806,7 @@ func main() { cfg.traces.readTemplateEndpoint, cfg.traces.tempoEndpoint, cfg.traces.writeOTLPHTTPEndpoint, - tracesUpstreamCACert, - tracesUpstreamClientCert, + tracesUpstreamTLSOptions, tracesv1.Logger(logger), tracesv1.WithRegistry(reg), tracesv1.WithHandlerInstrumenter(instrumenter), @@ -831,6 +832,7 @@ func main() { cfg.tls.clientAuthType, cfg.tls.cipherSuites, ) + if err != nil { stdlog.Fatalf("failed to initialize tls config: %v", err) } @@ -895,8 +897,7 @@ func main() { pm.GRPCMiddlewares, authorizers, logger, - tracesUpstreamCACert, - tracesUpstreamClientCert, + tracesUpstreamTLSOptions, ) if err != nil { stdlog.Fatalf("failed to initialize gRPC server: %v", err) @@ -939,6 +940,7 @@ func main() { cfg.tls.clientAuthType, cfg.tls.cipherSuites, ) + if err != nil { stdlog.Fatalf("failed to initialize tls config: %v", err) } @@ -1111,10 +1113,13 @@ func parseFlags() (config, error) { "File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS.") flag.StringVar(&cfg.logs.upstreamKeyFile, "logs.tls.key-file", "", "File containing the TLS client key to authenticate against upstream logs servers. Leave blank to disable mTLS.") + flag.BoolVar(&cfg.logs.enableCertWatcher, "logs.tls.watch-certs", false, + "Watch for certificate changes and reload") flag.StringVar(&cfg.logs.tenantHeader, "logs.tenant-header", "X-Scope-OrgID", "The name of the HTTP header containing the tenant ID to forward to the logs upstream.") flag.StringVar(&cfg.logs.tenantLabel, "logs.rules.tenant-label", "tenant_id", "The name of the rules label that should hold the tenant ID in logs upstreams.") + flag.StringVar(&rawLogsWriteEndpoint, "logs.write.endpoint", "", "The endpoint against which to make write requests for logs.") flag.StringVar(&rawLogsAuthExtractSelectors, "logs.auth.extract-selectors", "", @@ -1135,6 +1140,8 @@ func parseFlags() (config, error) { "File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS.") flag.StringVar(&cfg.metrics.upstreamKeyFile, "metrics.tls.key-file", "", "File containing the TLS client key to authenticate against upstream metrics servers. Leave blank to disable mTLS.") + flag.BoolVar(&cfg.metrics.enableCertWatcher, "metrics.tls.watch-certs", false, + "Watch for certificate changes and reload") flag.StringVar(&cfg.metrics.tenantHeader, "metrics.tenant-header", "THANOS-TENANT", "The name of the HTTP header containing the tenant ID to forward to the metrics upstreams.") flag.StringVar(&cfg.metrics.tenantLabel, "metrics.tenant-label", "tenant_id", @@ -1157,6 +1164,8 @@ func parseFlags() (config, error) { "File containing the TLS client certificates to authenticate against upstream logs servers. Leave blank to disable mTLS.") flag.StringVar(&cfg.traces.upstreamKeyFile, "traces.tls.key-file", "", "File containing the TLS client key to authenticate against upstream traces servers. Leave blank to disable mTLS.") + flag.BoolVar(&cfg.traces.enableCertWatcher, "traces.tls.watch-certs", false, + "Watch for certificate changes and reload") flag.StringVar(&cfg.traces.tenantHeader, "traces.tenant-header", "X-Tenant", "The name of the HTTP header containing the tenant ID to forward to upstream OpenTelemetry collector.") flag.StringVar(&cfg.tls.serverCertFile, "tls.server.cert-file", "", @@ -1458,12 +1467,12 @@ var gRPCRBAC = authorization.GRPCRBac{ } func newGRPCServer(cfg *config, tenantHeader string, tenantIDs map[string]string, pmis authentication.GRPCMiddlewareFunc, - authorizers map[string]rbac.Authorizer, logger log.Logger, tracesUpstreamCA []byte, tracesUpstreamCert *stdtls.Certificate, + authorizers map[string]rbac.Authorizer, logger log.Logger, upstreamTLSOptions *tls.UpstreamOptions, ) (*grpc.Server, error) { connOtel, err := tracesv1.NewOTelConnection( cfg.traces.writeOTLPGRPCEndpoint, tracesv1.WithLogger(logger), - tracesv1.WithUpstreamTLS(tracesUpstreamCA, tracesUpstreamCert), + tracesv1.WithUpstreamTLSOptions(upstreamTLSOptions), ) if err != nil { return nil, err diff --git a/tls/ca_watcher.go b/tls/ca_watcher.go new file mode 100644 index 000000000..6f0c98cb3 --- /dev/null +++ b/tls/ca_watcher.go @@ -0,0 +1,112 @@ +package tls + +import ( + "context" + "crypto/sha256" + "crypto/x509" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" +) + +// caCertificateWatcher poll for changes on the CA certificate file, if the CA change it will add it to the certificate Pool. +type caCertificateWatcher struct { + mutex sync.RWMutex + certPool *x509.CertPool + logger log.Logger + fileHashContent string + CAPath string + interval time.Duration +} + +// newCACertificateWatcher creates a new watcher for the CA file. +func newCACertificateWatcher(CAPath string, logger log.Logger, interval time.Duration, pool *x509.CertPool) (*caCertificateWatcher, error) { + w := &caCertificateWatcher{ + CAPath: CAPath, + logger: logger, + certPool: pool, + interval: interval, + } + err := w.loadCA() + return w, err +} + +// Watch for the changes on the certificate each interval, if the content changes +// a new certificate will be added to the pool. +func (w *caCertificateWatcher) Watch(ctx context.Context) error { + var timer *time.Timer + + scheduleNext := func() { + timer = time.NewTimer(w.interval) + } + scheduleNext() + + for { + select { + case <-ctx.Done(): + return nil + case <-timer.C: + err := w.loadCA() + if err != nil { + return err + } + scheduleNext() + } + } + +} + +func (w *caCertificateWatcher) loadCA() error { + + hash, err := w.hashFile(w.CAPath) + if err != nil { + level.Error(w.logger).Log("unable to read the file", "error", err.Error()) + return err + } + + // If file changed + if w.fileHashContent != hash { + // read content + caPEM, err := os.ReadFile(filepath.Clean(w.CAPath)) + if err != nil { + level.Error(w.logger).Log("failed to load CA %s: %w", w.CAPath, err) + return err + } + w.mutex.Lock() + defer w.mutex.Unlock() + if !w.certPool.AppendCertsFromPEM(caPEM) { + level.Error(w.logger).Log("failed to parse CA %s", w.CAPath) + return err + } + } + return nil +} + +func (w *caCertificateWatcher) pool() *x509.CertPool { + w.mutex.RLock() + defer w.mutex.RUnlock() + return w.certPool +} + +// hashFile returns the SHA256 hash of the file. +func (w *caCertificateWatcher) hashFile(file string) (string, error) { + f, err := os.Open(filepath.Clean(file)) + if err != nil { + return "", err + } + + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + + return fmt.Sprintf("%x", h.Sum(nil)), nil +} diff --git a/tls/ca_watcher_test.go b/tls/ca_watcher_test.go new file mode 100644 index 000000000..2e74cdaf7 --- /dev/null +++ b/tls/ca_watcher_test.go @@ -0,0 +1,123 @@ +package tls + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "os" + "sync" + "testing" + "time" + + "github.com/observatorium/api/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + certutil "k8s.io/client-go/util/cert" +) + +func TestCertWatcher(t *testing.T) { + logger := logger.NewLogger("info", logger.LogFormatLogfmt, "") + reloadInterval := 2 * time.Second + + caA, caPathA, cleanupA, err := newSelfSignedCA("ok") + defer cleanupA() + require.NoError(t, err) + caPool := x509.NewCertPool() + caPool.AddCert(caA) + + reloader, err := newCACertificateWatcher(caPathA, logger, reloadInterval, caPool) + require.NoError(t, err) + + cancelContext, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + err := reloader.Watch(cancelContext) + require.NoError(t, err) + wg.Done() + }() + // Start watch loop + + // Generate new CA + caB, caPathB, cleanupB, err := newSelfSignedCA("baz") + defer cleanupB() + require.NoError(t, err) + + cbPool := x509.NewCertPool() + cbPool.AddCert(caB) + err = swapCert(t, caPathA, caPathB) + require.NoError(t, err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(caA) + rootCAs.AddCert(caB) + + assert.Eventually(t, func() bool { + return rootCAs.Equal(reloader.pool()) + + }, 5*reloadInterval, reloadInterval) + + cancel() + wg.Wait() + +} + +func newSelfSignedCA(hostname string) (*x509.Certificate, string, func(), error) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, "", func() {}, fmt.Errorf("generation of private key failed: %v", err) + } + + ca, err := certutil.NewSelfSignedCACert(certutil.Config{CommonName: hostname}, privKey) + if err != nil { + return nil, "", func() {}, fmt.Errorf("generation of certificate, failed: %v", err) + } + + // Create a PEM block with the certificate + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Raw, + }) + + certPath, err := writeTempFile("cert", pemBytes) + if err != nil { + return nil, "", func() {}, fmt.Errorf("error writing cert data: %v", err) + } + + return ca, certPath, func() { + _ = os.Remove(certPath) + }, nil +} + +func writeTempFile(pattern string, data []byte) (string, error) { + f, err := os.CreateTemp("", pattern) + if err != nil { + return "", fmt.Errorf("error creating temp file: %v", err) + } + defer f.Close() + + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + } + + if err != nil { + return "", fmt.Errorf("error writing temporary file: %v", err) + } + + return f.Name(), nil +} + +func swapCert(t *testing.T, caPathA, caPathB string) error { + t.Log("renaming", caPathB, "to", caPathA) + if err := os.Rename(caPathB, caPathA); err != nil { + return err + } + return nil +} diff --git a/tls/config.go b/tls/config.go index 30ad88508..bac2d0a06 100644 --- a/tls/config.go +++ b/tls/config.go @@ -2,31 +2,12 @@ package tls import ( "crypto/tls" - "crypto/x509" "fmt" "github.com/go-kit/log" "github.com/go-kit/log/level" ) -// NewClientConfig returns a tls config for the reverse proxy handling if an upstream CA is given. -func NewClientConfig(upstreamCA []byte, upstreamCert *tls.Certificate) *tls.Config { - if len(upstreamCA) == 0 { - return nil - } - - cfg := &tls.Config{ - RootCAs: x509.NewCertPool(), - } - cfg.RootCAs.AppendCertsFromPEM(upstreamCA) - - if upstreamCert != nil { - cfg.Certificates = append(cfg.Certificates, *upstreamCert) - } - - return cfg -} - // NewServerConfig provides new server TLS configuration. func NewServerConfig(logger log.Logger, certFile, keyFile, minVersion, maxVersion, clientAuthType string, cipherSuites []string) (*tls.Config, error) { if certFile == "" && keyFile == "" { diff --git a/tls/options.go b/tls/options.go new file mode 100644 index 000000000..e9e952132 --- /dev/null +++ b/tls/options.go @@ -0,0 +1,159 @@ +package tls + +import ( + "context" + stdtls "crypto/tls" + "crypto/x509" + "os" + "time" + + rbacproxytls "github.com/brancz/kube-rbac-proxy/pkg/tls" + "github.com/go-kit/log" + "github.com/oklog/run" +) + +// UpstreamOptions represents the options of the upstream TLS configuration +// this structure contains the certificates and the watchers if the certificate/ca watchers are enabled. +type UpstreamOptions struct { + cert *stdtls.Certificate + ca []byte + certReloader *rbacproxytls.CertReloader + caReloader *caCertificateWatcher +} + +// NewUpstreamOptions create a new UpstreamOptions, if interval is nil, the watcher will not be enabled. +func NewUpstreamOptions(ctx context.Context, upstreamCertFile, upstreamKeyFile, upstreamCAFile string, + interval *time.Duration, logger log.Logger, g run.Group) (*UpstreamOptions, error) { + + // reload enabled + if interval != nil { + return newWithWatchers(ctx, upstreamCertFile, upstreamKeyFile, upstreamCAFile, *interval, logger, g) + } + + return newNoWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile) +} + +func newWithWatchers(ctx context.Context, upstreamCertFile, upstreamKeyFile, upstreamCAFile string, + interval time.Duration, logger log.Logger, g run.Group) (*UpstreamOptions, error) { + options := &UpstreamOptions{} + + if upstreamCertFile != "" && upstreamKeyFile != "" { + certReloader, err := startCertReloader(ctx, g, upstreamCertFile, upstreamKeyFile, interval) + if err != nil { + return nil, err + } + options.certReloader = certReloader + } + if upstreamCAFile != "" { + caPool := x509.NewCertPool() + caReloader, err := startCAReloader(ctx, g, upstreamCAFile, interval, logger, caPool) + if err != nil { + return nil, err + } + options.caReloader = caReloader + } + return options, nil +} + +func newNoWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile string) (*UpstreamOptions, error) { + options := &UpstreamOptions{} + if upstreamCertFile != "" && upstreamKeyFile != "" { + cert, err := stdtls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile) + if err != nil { + return nil, err + } + options.cert = &cert + } + + if upstreamCAFile != "" { + ca, err := os.ReadFile(upstreamCAFile) + if err != nil { + return nil, err + } + options.ca = ca + } + return options, nil +} + +func startCertReloader(ctx context.Context, g run.Group, + upstreamKeyFile, upstreamCertFile string, interval time.Duration) (*rbacproxytls.CertReloader, error) { + certReloader, err := rbacproxytls.NewCertReloader( + upstreamKeyFile, + upstreamCertFile, + interval, + ) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + g.Add(func() error { + return certReloader.Watch(ctx) + }, func(error) { + cancel() + }) + return certReloader, nil +} + +func startCAReloader(ctx context.Context, g run.Group, upstreamCAFile string, interval time.Duration, logger log.Logger, + pool *x509.CertPool) (*caCertificateWatcher, error) { + caReloader, err := newCACertificateWatcher(upstreamCAFile, logger, interval, pool) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + g.Add(func() error { + return caReloader.Watch(ctx) + }, func(error) { + cancel() + }) + return caReloader, nil +} + +// hasCA determine if the CA was specified. +func (uo *UpstreamOptions) hasCA() bool { + return len(uo.ca) != 0 || uo.caReloader != nil +} + +// hasUpstreamCerts determine if the hasUpstreamCerts were specified. +func (uo *UpstreamOptions) hasUpstreamCerts() bool { + return uo.cert != nil || uo.certReloader != nil +} + +// isCAReloadEnabled determine if the CA watcher is enabled. +func (uo *UpstreamOptions) isCAReloadEnabled() bool { + return uo.caReloader != nil +} + +// isCertReloaderEnabled determine if the certificate watcher is enabled. +func (uo *UpstreamOptions) isCertReloaderEnabled() bool { + return uo.certReloader != nil +} + +// NewClientConfig returns a tls config for the reverse proxy handling if an upstream CA is given. +// this will transform TLS UpstreamOptions to a tls.Config native TLS golang structure, if the watchers are enabled +// it will override the GetClientCertificate function. +func (uo *UpstreamOptions) NewClientConfig() *stdtls.Config { + if !uo.hasCA() { + return nil + } + cfg := &stdtls.Config{} + + if uo.hasUpstreamCerts() { + if uo.isCertReloaderEnabled() { + cfg.GetClientCertificate = func(info *stdtls.CertificateRequestInfo) (*stdtls.Certificate, error) { + return uo.certReloader.GetCertificate(nil) + } + } else { + cfg.Certificates = append(cfg.Certificates, *uo.cert) + } + } + + if uo.isCAReloadEnabled() { + cfg.RootCAs = uo.caReloader.pool() + } else { + cfg.RootCAs = x509.NewCertPool() + cfg.RootCAs.AppendCertsFromPEM(uo.ca) + } + return cfg +} diff --git a/tls/options_test.go b/tls/options_test.go new file mode 100644 index 000000000..50e3db30c --- /dev/null +++ b/tls/options_test.go @@ -0,0 +1,268 @@ +package tls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "os" + "testing" + "time" + + "github.com/observatorium/api/logger" + "github.com/oklog/run" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + certutil "k8s.io/client-go/util/cert" +) + +func newSelfSignedCert(t *testing.T, hostname string) (string, string, func(), error) { + var err error + certBytes, keyBytes, err := certutil.GenerateSelfSignedCertKey(hostname, nil, nil) + if err != nil { + t.Fatalf("generation of self signed cert and key failed: %v", err) + } + + certPath, err := writeTempFile("cert", certBytes) + if err != nil { + t.Fatalf("error writing cert data: %v", err) + return certPath, "", func() {}, err + } + keyPath, err := writeTempFile("key", keyBytes) + if err != nil { + t.Fatalf("error writing key data: %v", err) + return certPath, keyPath, func() { + _ = os.Remove(certPath) + }, err + } + + return certPath, keyPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }, nil +} + +func TestUpstreamOptions_NewClientConfigNoTimeInteval(t *testing.T) { + + ca, caPath, cleanCA, err := newSelfSignedCA("ok") + defer cleanCA() + require.NoError(t, err) + + caPool := x509.NewCertPool() + caPool.AddCert(ca) + + certPath, keyPath, cleanCerts, err := newSelfSignedCert(t, "local") + defer cleanCerts() + require.NoError(t, err) + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + var g run.Group + + logger := logger.NewLogger("info", logger.LogFormatLogfmt, "") + + tests := []struct { + name string + caPath string + certPath string + keyPath string + expectedErr bool + expectedNilCfg bool + expectedRootCA *x509.CertPool + expectedCertificates []tls.Certificate + }{ + { + name: "all enabled", + caPath: caPath, + certPath: certPath, + keyPath: keyPath, + expectedRootCA: caPool, + expectedCertificates: []tls.Certificate{cert}, + expectedNilCfg: false, + }, + { + name: "cert/key empty", + caPath: caPath, + expectedRootCA: caPool, + expectedNilCfg: false, + }, + { + name: "ca empty", + certPath: certPath, + keyPath: keyPath, + expectedCertificates: []tls.Certificate{cert}, + expectedNilCfg: true, + }, + { + name: "both empty", + expectedNilCfg: true, + }, + { + name: "invalid CA", + caPath: "/nowhere", + certPath: certPath, + keyPath: keyPath, + expectedNilCfg: true, + expectedErr: true, + }, + { + name: "invalid cert", + caPath: caPath, + certPath: "/nowhere", + keyPath: keyPath, + expectedNilCfg: true, + expectedErr: true, + }, + { + name: "invalid key", + caPath: caPath, + certPath: certPath, + keyPath: "/nowhere", + expectedNilCfg: true, + expectedErr: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + opts, err := NewUpstreamOptions( + context.Background(), + tc.certPath, tc.keyPath, tc.caPath, nil, logger, g) + + if tc.expectedErr { + assert.Error(t, err) + assert.Nil(t, opts) + } else { + assert.NoError(t, err) + + cfg := opts.NewClientConfig() + if tc.expectedNilCfg { + assert.Nil(t, cfg) + } else { + assert.Equal(t, tc.expectedCertificates, cfg.Certificates) + if tc.expectedRootCA != nil { + assert.True(t, tc.expectedRootCA.Equal(cfg.RootCAs)) + } else { + assert.Nil(t, cfg.RootCAs) + } + } + } + }) + } +} + +func TestUpstreamOptions_NewClientConfigTimeInterval(t *testing.T) { + ca, caPath, cleanCA, err := newSelfSignedCA("ok") + defer cleanCA() + require.NoError(t, err) + + caPool := x509.NewCertPool() + caPool.AddCert(ca) + + certPath, keyPath, cleanCerts, err := newSelfSignedCert(t, "local") + defer cleanCerts() + require.NoError(t, err) + + _, err = tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + var g run.Group + logger := logger.NewLogger("info", logger.LogFormatLogfmt, "") + + tests := []struct { + name string + caPath string + certPath string + keyPath string + expectedErr bool + expectedNilCfg bool + setGetClientCertificate bool + expectedRootCA *x509.CertPool + expectedCertificates []tls.Certificate + }{ + { + name: "all enabled", + caPath: caPath, + certPath: certPath, + keyPath: keyPath, + expectedNilCfg: false, + setGetClientCertificate: true, + expectedRootCA: caPool, + }, + { + name: "cert/key empty", + caPath: caPath, + expectedRootCA: caPool, + setGetClientCertificate: false, + expectedNilCfg: false, + }, + { + name: "ca empty", + certPath: certPath, + keyPath: keyPath, + expectedNilCfg: true, + }, + { + name: "both empty", + expectedNilCfg: true, + }, + { + name: "invalid CA", + caPath: "/nowhere", + certPath: certPath, + keyPath: keyPath, + expectedNilCfg: true, + expectedErr: true, + }, + { + name: "invalid cert", + caPath: caPath, + certPath: "/nowhere", + keyPath: keyPath, + expectedNilCfg: true, + expectedErr: true, + }, + { + name: "invalid key", + caPath: caPath, + certPath: certPath, + keyPath: "/nowhere", + expectedNilCfg: true, + expectedErr: true, + }, + } + + interval := time.Second * 1 + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + opts, err := NewUpstreamOptions( + context.Background(), + tc.certPath, tc.keyPath, tc.caPath, &interval, logger, g) + + if tc.expectedErr { + assert.Error(t, err) + assert.Nil(t, opts) + } else { + assert.NoError(t, err) + + cfg := opts.NewClientConfig() + if tc.expectedNilCfg { + assert.Nil(t, cfg) + } else { + assert.Equal(t, tc.expectedCertificates, cfg.Certificates) + if tc.expectedRootCA != nil { + assert.True(t, tc.expectedRootCA.Equal(cfg.RootCAs)) + } else { + assert.Nil(t, cfg.RootCAs) + } + + if tc.setGetClientCertificate { + assert.NotNil(t, cfg.GetClientCertificate) + } else { + assert.Nil(t, cfg.GetClientCertificate) + } + } + } + }) + } +}