diff --git a/pkg/config/dynamic/fixtures/sample.toml b/pkg/config/dynamic/fixtures/sample.toml index b3c84571ce..9ac084b3a6 100644 --- a/pkg/config/dynamic/fixtures/sample.toml +++ b/pkg/config/dynamic/fixtures/sample.toml @@ -425,14 +425,14 @@ scheme = "foobar" path = "foobar" port = 42 - interval = "foobar" - timeout = "foobar" + interval = "1s" + timeout = "1s" hostname = "foobar" [http.services.Service0.loadBalancer.healthCheck.headers] name0 = "foobar" name1 = "foobar" [http.services.Service0.loadBalancer.responseForwarding] - flushInterval = "foobar" + flushInterval = "100ms" [tcp] [tcp.routers] diff --git a/pkg/config/dynamic/http_config.go b/pkg/config/dynamic/http_config.go index af97325d7c..cdc713de49 100644 --- a/pkg/config/dynamic/http_config.go +++ b/pkg/config/dynamic/http_config.go @@ -9,6 +9,19 @@ import ( "github.com/traefik/traefik/v2/pkg/types" ) +const ( + // DefaultHealthCheckInterval is the default value for the ServerHealthCheck interval. + DefaultHealthCheckInterval = ptypes.Duration(30 * time.Second) + // DefaultHealthCheckTimeout is the default value for the ServerHealthCheck timeout. + DefaultHealthCheckTimeout = ptypes.Duration(5 * time.Second) + + // DefaultPassHostHeader is the default value for the ServersLoadBalancer passHostHeader. + DefaultPassHostHeader = true + + // DefaultFlushInterval is the default value for the ResponseForwarding flush interval. + DefaultFlushInterval = ptypes.Duration(100 * time.Millisecond) +) + // +k8s:deepcopy-gen=true // HTTPConfiguration contains all the HTTP configuration parameters. @@ -192,7 +205,7 @@ type ResponseForwarding struct { // This configuration is ignored when ReverseProxy recognizes a response as a streaming response; // for such responses, writes are flushed to the client immediately. // Default: 100ms - FlushInterval string `json:"flushInterval,omitempty" toml:"flushInterval,omitempty" yaml:"flushInterval,omitempty" export:"true"` + FlushInterval ptypes.Duration `json:"flushInterval,omitempty" toml:"flushInterval,omitempty" yaml:"flushInterval,omitempty" export:"true"` } // +k8s:deepcopy-gen=true @@ -213,14 +226,14 @@ func (s *Server) SetDefaults() { // ServerHealthCheck holds the HealthCheck configuration. 
type ServerHealthCheck struct { - Scheme string `json:"scheme,omitempty" toml:"scheme,omitempty" yaml:"scheme,omitempty" export:"true"` - Path string `json:"path,omitempty" toml:"path,omitempty" yaml:"path,omitempty" export:"true"` - Method string `json:"method,omitempty" toml:"method,omitempty" yaml:"method,omitempty" export:"true"` - Port int `json:"port,omitempty" toml:"port,omitempty,omitzero" yaml:"port,omitempty" export:"true"` - // TODO change string to ptypes.Duration - Interval string `json:"interval,omitempty" toml:"interval,omitempty" yaml:"interval,omitempty" export:"true"` - // TODO change string to ptypes.Duration - Timeout string `json:"timeout,omitempty" toml:"timeout,omitempty" yaml:"timeout,omitempty" export:"true"` + Scheme string `json:"scheme,omitempty" toml:"scheme,omitempty" yaml:"scheme,omitempty" export:"true"` + Mode string `json:"mode,omitempty" toml:"mode,omitempty" yaml:"mode,omitempty" export:"true"` + Path string `json:"path,omitempty" toml:"path,omitempty" yaml:"path,omitempty" export:"true"` + Method string `json:"method,omitempty" toml:"method,omitempty" yaml:"method,omitempty" export:"true"` + Status int `json:"status,omitempty" toml:"status,omitempty" yaml:"status,omitempty" export:"true"` + Port int `json:"port,omitempty" toml:"port,omitempty,omitzero" yaml:"port,omitempty" export:"true"` + Interval ptypes.Duration `json:"interval,omitempty" toml:"interval,omitempty" yaml:"interval,omitempty" export:"true"` + Timeout ptypes.Duration `json:"timeout,omitempty" toml:"timeout,omitempty" yaml:"timeout,omitempty" export:"true"` Hostname string `json:"hostname,omitempty" toml:"hostname,omitempty" yaml:"hostname,omitempty"` FollowRedirects *bool `json:"followRedirects" toml:"followRedirects" yaml:"followRedirects" export:"true"` Headers map[string]string `json:"headers,omitempty" toml:"headers,omitempty" yaml:"headers,omitempty" export:"true"` diff --git a/pkg/config/label/label_test.go b/pkg/config/label/label_test.go index a11fb3976f..004825beed 100644 --- a/pkg/config/label/label_test.go +++ b/pkg/config/label/label_test.go @@ -149,15 +149,15 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.http.services.Service0.loadbalancer.healthcheck.headers.name0": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.headers.name1": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.hostname": "foobar", - "traefik.http.services.Service0.loadbalancer.healthcheck.interval": "foobar", + "traefik.http.services.Service0.loadbalancer.healthcheck.interval": "1s", "traefik.http.services.Service0.loadbalancer.healthcheck.path": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.method": "foobar", "traefik.http.services.Service0.loadbalancer.healthcheck.port": "42", "traefik.http.services.Service0.loadbalancer.healthcheck.scheme": "foobar", - "traefik.http.services.Service0.loadbalancer.healthcheck.timeout": "foobar", + "traefik.http.services.Service0.loadbalancer.healthcheck.timeout": "1s", "traefik.http.services.Service0.loadbalancer.healthcheck.followredirects": "true", "traefik.http.services.Service0.loadbalancer.passhostheader": "true", - "traefik.http.services.Service0.loadbalancer.responseforwarding.flushinterval": "foobar", + "traefik.http.services.Service0.loadbalancer.responseforwarding.flushinterval": "100ms", "traefik.http.services.Service0.loadbalancer.server.scheme": "foobar", "traefik.http.services.Service0.loadbalancer.server.port": "8080", "traefik.http.services.Service0.loadbalancer.sticky.cookie.name": 
"foobar", @@ -165,15 +165,15 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.http.services.Service1.loadbalancer.healthcheck.headers.name0": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.headers.name1": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.hostname": "foobar", - "traefik.http.services.Service1.loadbalancer.healthcheck.interval": "foobar", + "traefik.http.services.Service1.loadbalancer.healthcheck.interval": "1s", "traefik.http.services.Service1.loadbalancer.healthcheck.path": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.method": "foobar", "traefik.http.services.Service1.loadbalancer.healthcheck.port": "42", "traefik.http.services.Service1.loadbalancer.healthcheck.scheme": "foobar", - "traefik.http.services.Service1.loadbalancer.healthcheck.timeout": "foobar", + "traefik.http.services.Service1.loadbalancer.healthcheck.timeout": "1s", "traefik.http.services.Service1.loadbalancer.healthcheck.followredirects": "true", "traefik.http.services.Service1.loadbalancer.passhostheader": "true", - "traefik.http.services.Service1.loadbalancer.responseforwarding.flushinterval": "foobar", + "traefik.http.services.Service1.loadbalancer.responseforwarding.flushinterval": "100ms", "traefik.http.services.Service1.loadbalancer.server.scheme": "foobar", "traefik.http.services.Service1.loadbalancer.server.port": "8080", "traefik.http.services.Service1.loadbalancer.sticky": "false", @@ -658,8 +658,8 @@ func TestDecodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -669,7 +669,7 @@ func TestDecodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -686,8 +686,8 @@ func TestDecodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -697,7 +697,7 @@ func TestDecodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1162,8 +1162,8 @@ func TestEncodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -1172,7 +1172,7 @@ func TestEncodeConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1189,8 +1189,8 @@ func TestEncodeConfiguration(t *testing.T) { Path: "foobar", Method: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", Headers: map[string]string{ "name0": "foobar", @@ -1199,7 +1199,7 @@ func TestEncodeConfiguration(t *testing.T) { }, 
PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, @@ -1359,14 +1359,15 @@ func TestEncodeConfiguration(t *testing.T) { "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Headers.name1": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Hostname": "foobar", - "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Interval": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Interval": "1000000000", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Path": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Method": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Port": "42", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Scheme": "foobar", - "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Timeout": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Status": "0", + "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Timeout": "1000000000", "traefik.HTTP.Services.Service0.LoadBalancer.PassHostHeader": "true", - "traefik.HTTP.Services.Service0.LoadBalancer.ResponseForwarding.FlushInterval": "foobar", + "traefik.HTTP.Services.Service0.LoadBalancer.ResponseForwarding.FlushInterval": "100000000", "traefik.HTTP.Services.Service0.LoadBalancer.server.Port": "8080", "traefik.HTTP.Services.Service0.LoadBalancer.server.Scheme": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.Sticky.Cookie.Name": "foobar", @@ -1375,14 +1376,15 @@ func TestEncodeConfiguration(t *testing.T) { "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Headers.name0": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Headers.name1": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Hostname": "foobar", - "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Interval": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Interval": "1000000000", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Path": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Method": "foobar", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Port": "42", "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Scheme": "foobar", - "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Timeout": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Status": "0", + "traefik.HTTP.Services.Service1.LoadBalancer.HealthCheck.Timeout": "1000000000", "traefik.HTTP.Services.Service1.LoadBalancer.PassHostHeader": "true", - "traefik.HTTP.Services.Service1.LoadBalancer.ResponseForwarding.FlushInterval": "foobar", + "traefik.HTTP.Services.Service1.LoadBalancer.ResponseForwarding.FlushInterval": "100000000", "traefik.HTTP.Services.Service1.LoadBalancer.server.Port": "8080", "traefik.HTTP.Services.Service1.LoadBalancer.server.Scheme": "foobar", "traefik.HTTP.Services.Service0.LoadBalancer.HealthCheck.Headers.name0": "foobar", diff --git a/pkg/config/runtime/runtime.go b/pkg/config/runtime/runtime.go index 48d527f65e..8f9f2a4355 100644 --- a/pkg/config/runtime/runtime.go +++ b/pkg/config/runtime/runtime.go @@ -15,6 +15,12 @@ const ( StatusWarning = "warning" ) +// Status of the servers. +const ( + StatusUp = "UP" + StatusDown = "DOWN" +) + // Configuration holds the information about the currently running traefik instance. 
type Configuration struct { Routers map[string]*RouterInfo `json:"routers,omitempty"` diff --git a/pkg/config/runtime/runtime_test.go b/pkg/config/runtime/runtime_test.go index 7b309b53df..e959ca1629 100644 --- a/pkg/config/runtime/runtime_test.go +++ b/pkg/config/runtime/runtime_test.go @@ -2,9 +2,11 @@ package runtime_test import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" ) @@ -49,7 +51,7 @@ func TestPopulateUsedBy(t *testing.T) { {URL: "http://127.0.0.1:8086"}, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, @@ -159,7 +161,7 @@ func TestPopulateUsedBy(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, @@ -177,7 +179,7 @@ func TestPopulateUsedBy(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, diff --git a/pkg/healthcheck/healthcheck.go b/pkg/healthcheck/healthcheck.go index 8bd8d1e17c..0fc0d56bf4 100644 --- a/pkg/healthcheck/healthcheck.go +++ b/pkg/healthcheck/healthcheck.go @@ -8,425 +8,258 @@ import ( "net/http" "net/url" "strconv" - "strings" - "sync" "time" gokitmetrics "github.com/go-kit/kit/metrics" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/metrics" - "github.com/traefik/traefik/v2/pkg/safe" - "github.com/vulcand/oxy/v2/roundrobin" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" ) -const ( - serverUp = "UP" - serverDown = "DOWN" -) - -var ( - singleton *HealthCheck - once sync.Once -) +const modeGRPC = "grpc" -// Balancer is the set of operations required to manage the list of servers in a load-balancer. -type Balancer interface { - Servers() []*url.URL - RemoveServer(u *url.URL) error - UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error +// StatusSetter should be implemented by a service that, when the status of a +// registered target changes, needs to be notified of that change. +type StatusSetter interface { + SetStatus(ctx context.Context, childName string, up bool) } -// BalancerHandler includes functionality for load-balancing management. -type BalancerHandler interface { - ServeHTTP(w http.ResponseWriter, req *http.Request) - Balancer -} - -// BalancerStatusHandler is an http Handler that does load-balancing, -// and updates its parents of its status. -type BalancerStatusHandler interface { - BalancerHandler - StatusUpdater +// StatusUpdater should be implemented by a service that, when its status +// changes (e.g. all of its children are down), needs to propagate upwards (to +// its parent(s)) that change. +type StatusUpdater interface { + RegisterStatusUpdater(fn func(up bool)) error } -type metricsHealthcheck struct { - serverUpGauge gokitmetrics.Gauge +type metricsHealthCheck interface { + ServiceServerUpGauge() gokitmetrics.Gauge } -// Options are the public health check options. 
-type Options struct { - Headers map[string]string - Hostname string - Scheme string - Path string - Method string - Port int - FollowRedirects bool - Transport http.RoundTripper - Interval time.Duration - Timeout time.Duration - LB Balancer -} +type ServiceHealthChecker struct { + balancer StatusSetter + info *runtime.ServiceInfo -func (opt Options) String() string { - return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Method: %s Port: %d Interval: %s Timeout: %s FollowRedirects: %v]", opt.Hostname, opt.Headers, opt.Path, opt.Method, opt.Port, opt.Interval, opt.Timeout, opt.FollowRedirects) -} + config *dynamic.ServerHealthCheck + interval time.Duration + timeout time.Duration -type backendURL struct { - url *url.URL - weight int -} + metrics metricsHealthCheck -// BackendConfig HealthCheck configuration for a backend. -type BackendConfig struct { - Options - name string - disabledURLs []backendURL + client *http.Client + targets map[string]*url.URL } -func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) { - u, err := serverURL.Parse(b.Path) - if err != nil { - return nil, err - } +func NewServiceHealthChecker(ctx context.Context, metrics metricsHealthCheck, config *dynamic.ServerHealthCheck, service StatusSetter, info *runtime.ServiceInfo, transport http.RoundTripper, targets map[string]*url.URL) *ServiceHealthChecker { + logger := log.FromContext(ctx) - if len(b.Scheme) > 0 { - u.Scheme = b.Scheme + interval := time.Duration(config.Interval) + if interval <= 0 { + logger.Error("Health check interval smaller than zero") + interval = time.Duration(dynamic.DefaultHealthCheckInterval) } - if b.Port != 0 { - u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port)) + timeout := time.Duration(config.Timeout) + if timeout <= 0 { + logger.Error("Health check timeout smaller than zero") + timeout = time.Duration(dynamic.DefaultHealthCheckTimeout) } - return http.NewRequest(http.MethodGet, u.String(), http.NoBody) -} - -// setRequestOptions sets all request options present on the BackendConfig. -func (b *BackendConfig) setRequestOptions(req *http.Request) *http.Request { - if b.Options.Hostname != "" { - req.Host = b.Options.Hostname + if timeout >= interval { + logger.Warnf("Health check timeout should be lower than the health check interval. Interval set to timeout + 1 second (%s).", interval) + interval = timeout + time.Second } - for k, v := range b.Options.Headers { - req.Header.Set(k, v) + client := &http.Client{ + Transport: transport, } - if b.Options.Method != "" { - req.Method = strings.ToUpper(b.Options.Method) - } - - return req -} - -// HealthCheck struct. -type HealthCheck struct { - Backends map[string]*BackendConfig - metrics metricsHealthcheck - cancel context.CancelFunc -} - -// SetBackendsConfiguration set backends configuration. 
-func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) { - hc.Backends = backends - if hc.cancel != nil { - hc.cancel() + if config.FollowRedirects != nil && !*config.FollowRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } } - ctx, cancel := context.WithCancel(parentCtx) - hc.cancel = cancel - for _, backend := range backends { - safe.Go(func() { - hc.execute(ctx, backend) - }) + return &ServiceHealthChecker{ + balancer: service, + info: info, + config: config, + interval: interval, + timeout: timeout, + targets: targets, + client: client, + metrics: metrics, } } -func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) { - logger := log.FromContext(ctx) - - logger.Debugf("Initial health check for backend: %q", backend.name) - hc.checkServersLB(ctx, backend) - - ticker := time.NewTicker(backend.Interval) +func (shc *ServiceHealthChecker) Launch(ctx context.Context) { + ticker := time.NewTicker(shc.interval) defer ticker.Stop() + for { select { case <-ctx.Done(): - logger.Debugf("Stopping current health check goroutines of backend: %s", backend.name) return - case <-ticker.C: - logger.Debugf("Routine health check refresh for backend: %s", backend.name) - hc.checkServersLB(ctx, backend) - } - } -} - -func (hc *HealthCheck) checkServersLB(ctx context.Context, backend *BackendConfig) { - logger := log.FromContext(ctx) - enabledURLs := backend.LB.Servers() + case <-ticker.C: + for proxyName, target := range shc.targets { + select { + case <-ctx.Done(): + return + default: + } - var newDisabledURLs []backendURL - for _, disabledURL := range backend.disabledURLs { - serverUpMetricValue := float64(0) + up := true + serverUpMetricValue := float64(1) - if err := checkHealth(disabledURL.url, backend); err == nil { - logger.Warnf("Health check up: returning to server list. Backend: %q URL: %q Weight: %d", - backend.name, disabledURL.url.String(), disabledURL.weight) - if err = backend.LB.UpsertServer(disabledURL.url, roundrobin.Weight(disabledURL.weight)); err != nil { - logger.Error(err) - } - serverUpMetricValue = 1 - } else { - logger.Warnf("Health check still failing. Backend: %q URL: %q Reason: %s", backend.name, disabledURL.url.String(), err) - newDisabledURLs = append(newDisabledURLs, disabledURL) - } + if err := shc.executeHealthCheck(ctx, shc.config, target); err != nil { + // The context is canceled when the dynamic configuration is refreshed. + if errors.Is(err, context.Canceled) { + return + } - labelValues := []string{"service", backend.name, "url", disabledURL.url.String()} - hc.metrics.serverUpGauge.With(labelValues...).Set(serverUpMetricValue) - } + log.FromContext(ctx).WithError(err).WithField("targetURL", target.String()).Error("Health check failed.") - backend.disabledURLs = newDisabledURLs + up = false + serverUpMetricValue = float64(0) + } - for _, enabledURL := range enabledURLs { - serverUpMetricValue := float64(1) + shc.balancer.SetStatus(ctx, proxyName, up) - if err := checkHealth(enabledURL, backend); err != nil { - weight := 1 - rr, ok := backend.LB.(*roundrobin.RoundRobin) - if ok { - var gotWeight bool - weight, gotWeight = rr.ServerWeight(enabledURL) - if !gotWeight { - weight = 1 + statusStr := runtime.StatusDown + if up { + statusStr = runtime.StatusUp } - } - logger.Warnf("Health check failed, removing from server list. 
Backend: %q URL: %q Weight: %d Reason: %s", - backend.name, enabledURL.String(), weight, err) - if err := backend.LB.RemoveServer(enabledURL); err != nil { - logger.Error(err) - } + shc.info.UpdateServerStatus(target.String(), statusStr) - backend.disabledURLs = append(backend.disabledURLs, backendURL{enabledURL, weight}) - serverUpMetricValue = 0 + shc.metrics.ServiceServerUpGauge(). + With("service", proxyName, "url", target.String()). + Set(serverUpMetricValue) + } } - - labelValues := []string{"service", backend.name, "url", enabledURL.String()} - hc.metrics.serverUpGauge.With(labelValues...).Set(serverUpMetricValue) } } -// GetHealthCheck returns the health check which is guaranteed to be a singleton. -func GetHealthCheck(registry metrics.Registry) *HealthCheck { - once.Do(func() { - singleton = newHealthCheck(registry) - }) - return singleton -} +func (shc *ServiceHealthChecker) executeHealthCheck(ctx context.Context, config *dynamic.ServerHealthCheck, target *url.URL) error { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(shc.timeout)) + defer cancel() -func newHealthCheck(registry metrics.Registry) *HealthCheck { - return &HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{ - serverUpGauge: registry.ServiceServerUpGauge(), - }, + if config.Mode == modeGRPC { + return shc.checkHealthGRPC(ctx, target) } + return shc.checkHealthHTTP(ctx, target) } -// NewBackendConfig Instantiate a new BackendConfig. -func NewBackendConfig(options Options, backendName string) *BackendConfig { - return &BackendConfig{ - Options: options, - name: backendName, - } -} - -// checkHealth returns a nil error in case it was successful and otherwise -// a non-nil error with a meaningful description why the health check failed. -func checkHealth(serverURL *url.URL, backend *BackendConfig) error { - req, err := backend.newRequest(serverURL) +// checkHealthHTTP returns an error with a meaningful description if the health check failed. +// Dedicated to HTTP servers. +func (shc *ServiceHealthChecker) checkHealthHTTP(ctx context.Context, target *url.URL) error { + req, err := shc.newRequest(ctx, target) if err != nil { - return fmt.Errorf("failed to create HTTP request: %w", err) + return fmt.Errorf("create HTTP request: %w", err) } - req = backend.setRequestOptions(req) - - client := http.Client{ - Timeout: backend.Options.Timeout, - Transport: backend.Options.Transport, - } - - if !backend.FollowRedirects { - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - } - - resp, err := client.Do(req) + resp, err := shc.client.Do(req) if err != nil { return fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if shc.config.Status == 0 && (resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest) { return fmt.Errorf("received error status code: %v", resp.StatusCode) } - return nil -} + if shc.config.Status != 0 && shc.config.Status != resp.StatusCode { + return fmt.Errorf("received error status code: %v expected status code: %v", resp.StatusCode, shc.config.Status) + } -// StatusUpdater should be implemented by a service that, when its status -// changes (e.g. all if its children are down), needs to propagate upwards (to -// their parent(s)) that change. 
-type StatusUpdater interface { - RegisterStatusUpdater(fn func(up bool)) error + return nil } -// NewLBStatusUpdater returns a new LbStatusUpdater. -func NewLBStatusUpdater(bh BalancerHandler, info *runtime.ServiceInfo, hc *dynamic.ServerHealthCheck) *LbStatusUpdater { - return &LbStatusUpdater{ - BalancerHandler: bh, - serviceInfo: info, - wantsHealthCheck: hc != nil, +func (shc *ServiceHealthChecker) newRequest(ctx context.Context, target *url.URL) (*http.Request, error) { + u, err := target.Parse(shc.config.Path) + if err != nil { + return nil, err } -} -// LbStatusUpdater wraps a BalancerHandler and a ServiceInfo, -// so it can keep track of the status of a server in the ServiceInfo. -type LbStatusUpdater struct { - BalancerHandler - serviceInfo *runtime.ServiceInfo // can be nil - updaters []func(up bool) - wantsHealthCheck bool -} - -// RegisterStatusUpdater adds fn to the list of hooks that are run when the -// status of the Balancer changes. -// Not thread safe. -func (lb *LbStatusUpdater) RegisterStatusUpdater(fn func(up bool)) error { - if !lb.wantsHealthCheck { - return errors.New("healthCheck not enabled in config for this loadbalancer service") + if len(shc.config.Scheme) > 0 { + u.Scheme = shc.config.Scheme } - lb.updaters = append(lb.updaters, fn) - return nil -} + if shc.config.Port != 0 { + u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(shc.config.Port)) + } -// RemoveServer removes the given server from the BalancerHandler, -// and updates the status of the server to "DOWN". -func (lb *LbStatusUpdater) RemoveServer(u *url.URL) error { - // TODO(mpl): when we have the freedom to change the signature of RemoveServer - // (kinda stuck because of oxy for now), let's pass around a context to improve - // logging. - ctx := context.TODO() - upBefore := len(lb.BalancerHandler.Servers()) > 0 - err := lb.BalancerHandler.RemoveServer(u) + req, err := http.NewRequestWithContext(ctx, shc.config.Method, u.String(), http.NoBody) if err != nil { - return err + return nil, fmt.Errorf("failed to create HTTP request: %w", err) } - if lb.serviceInfo != nil { - lb.serviceInfo.UpdateServerStatus(u.String(), serverDown) - } - log.FromContext(ctx).Debugf("child %s now %s", u.String(), serverDown) - if !upBefore { - // we were already down, and we still are, no need to propagate. - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverDown) - return nil - } - if len(lb.BalancerHandler.Servers()) > 0 { - // we were up, and we still are, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverUp) - return nil + if shc.config.Hostname != "" { + req.Host = shc.config.Hostname } - log.FromContext(ctx).Debugf("Propagating new %s status", serverDown) - for _, fn := range lb.updaters { - fn(false) + for k, v := range shc.config.Headers { + req.Header.Set(k, v) } - return nil + + return req, nil } -// UpsertServer adds the given server to the BalancerHandler, -// and updates the status of the server to "UP". -func (lb *LbStatusUpdater) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - ctx := context.TODO() - upBefore := len(lb.BalancerHandler.Servers()) > 0 - err := lb.BalancerHandler.UpsertServer(u, options...) +// checkHealthGRPC returns an error with a meaningful description if the health check failed. +// Dedicated to gRPC servers implementing gRPC Health Checking Protocol v1. 
+func (shc *ServiceHealthChecker) checkHealthGRPC(ctx context.Context, serverURL *url.URL) error { + u, err := serverURL.Parse(shc.config.Path) if err != nil { - return err + return fmt.Errorf("failed to parse server URL: %w", err) } - if lb.serviceInfo != nil { - lb.serviceInfo.UpdateServerStatus(u.String(), serverUp) - } - log.FromContext(ctx).Debugf("child %s now %s", u.String(), serverUp) - if upBefore { - // we were up, and we still are, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", serverUp) - return nil + port := u.Port() + if shc.config.Port != 0 { + port = strconv.Itoa(shc.config.Port) } - log.FromContext(ctx).Debugf("Propagating new %s status", serverUp) - for _, fn := range lb.updaters { - fn(true) - } - return nil -} + serverAddr := net.JoinHostPort(u.Hostname(), port) -// Balancers is a list of Balancers(s) that implements the Balancer interface. -type Balancers []Balancer - -// Servers returns the deduplicated server URLs from all the Balancer. -// Note that the deduplication is only possible because all the underlying -// balancers are of the same kind (the oxy implementation). -// The comparison property is the same as the one found at: -// https://github.com/vulcand/oxy/blob/fb2728c857b7973a27f8de2f2190729c0f22cf49/roundrobin/rr.go#L347. -func (b Balancers) Servers() []*url.URL { - seen := make(map[string]struct{}) - - var servers []*url.URL - for _, lb := range b { - for _, server := range lb.Servers() { - key := serverKey(server) - if _, ok := seen[key]; ok { - continue - } + var opts []grpc.DialOption + switch shc.config.Scheme { + case "http", "h2c", "": + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } - servers = append(servers, server) - seen[key] = struct{}{} + conn, err := grpc.DialContext(ctx, serverAddr, opts...) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("fail to connect to %s within %s: %w", serverAddr, shc.config.Timeout, err) } + return fmt.Errorf("fail to connect to %s: %w", serverAddr, err) } + defer func() { _ = conn.Close() }() - return servers -} - -// RemoveServer removes the given server from all the Balancer, -// and updates the status of the server to "DOWN". -func (b Balancers) RemoveServer(u *url.URL) error { - for _, lb := range b { - if err := lb.RemoveServer(u); err != nil { - return err + resp, err := healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{}) + if err != nil { + if stat, ok := status.FromError(err); ok { + switch stat.Code() { + case codes.Unimplemented: + return fmt.Errorf("gRPC server does not implement the health protocol: %w", err) + case codes.DeadlineExceeded: + return fmt.Errorf("gRPC health check timeout: %w", err) + case codes.Canceled: + return context.Canceled + } } + + return fmt.Errorf("gRPC health check failed: %w", err) } - return nil -} -// UpsertServer adds the given server to all the Balancer, -// and updates the status of the server to "UP". 
-func (b Balancers) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - for _, lb := range b { - if err := lb.UpsertServer(u, options...); err != nil { - return err - } + if resp.Status != healthpb.HealthCheckResponse_SERVING { + return fmt.Errorf("received gRPC status code: %v", resp.Status) } - return nil -} -func serverKey(u *url.URL) string { - return u.Path + u.Host + u.Scheme + return nil } diff --git a/pkg/healthcheck/healthcheck_test.go b/pkg/healthcheck/healthcheck_test.go index 73254b27e9..3e80ea3ad7 100644 --- a/pkg/healthcheck/healthcheck_test.go +++ b/pkg/healthcheck/healthcheck_test.go @@ -11,347 +11,181 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" + "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/testhelpers" - "github.com/vulcand/oxy/v2/roundrobin" + healthpb "google.golang.org/grpc/health/grpc_health_v1" ) -const ( - healthCheckInterval = 200 * time.Millisecond - healthCheckTimeout = 100 * time.Millisecond -) - -const delta float64 = 1e-10 - -type testHandler struct { - done func() - healthSequence []int -} - -func TestSetBackendsConfiguration(t *testing.T) { - testCases := []struct { - desc string - startHealthy bool - healthSequence []int - expectedNumRemovedServers int - expectedNumUpsertedServers int - expectedGaugeValue float64 - }{ - { - desc: "healthy server staying healthy", - startHealthy: true, - healthSequence: []int{http.StatusOK}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server staying healthy (StatusNoContent)", - startHealthy: true, - healthSequence: []int{http.StatusNoContent}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server staying healthy (StatusPermanentRedirect)", - startHealthy: true, - healthSequence: []int{http.StatusPermanentRedirect}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 1, - }, - { - desc: "healthy server becoming sick", - startHealthy: true, - healthSequence: []int{http.StatusServiceUnavailable}, - expectedNumRemovedServers: 1, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 0, - }, - { - desc: "sick server becoming healthy", - startHealthy: false, - healthSequence: []int{http.StatusOK}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 1, - expectedGaugeValue: 1, - }, - { - desc: "sick server staying sick", - startHealthy: false, - healthSequence: []int{http.StatusServiceUnavailable}, - expectedNumRemovedServers: 0, - expectedNumUpsertedServers: 0, - expectedGaugeValue: 0, - }, - { - desc: "healthy server toggling to sick and back to healthy", - startHealthy: true, - healthSequence: []int{http.StatusServiceUnavailable, http.StatusOK}, - expectedNumRemovedServers: 1, - expectedNumUpsertedServers: 1, - expectedGaugeValue: 1, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - // The context is passed to the health check and canonically canceled by - // the test server once all expected requests have been received. 
- ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ts := newTestServer(cancel, test.healthSequence) - defer ts.Close() - - lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} - backend := NewBackendConfig(Options{ - Path: "/path", - Interval: healthCheckInterval, - Timeout: healthCheckTimeout, - LB: lb, - }, "backendName") - - serverURL := testhelpers.MustParseURL(ts.URL) - if test.startHealthy { - lb.servers = append(lb.servers, serverURL) - } else { - backend.disabledURLs = append(backend.disabledURLs, backendURL{url: serverURL, weight: 1}) - } - - collectingMetrics := &testhelpers.CollectingGauge{} - check := HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{serverUpGauge: collectingMetrics}, - } - - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - check.execute(ctx, backend) - wg.Done() - }() - - // Make test timeout dependent on number of expected requests, health - // check interval, and a safety margin. - timeout := time.Duration(len(test.healthSequence)*int(healthCheckInterval) + 500) - select { - case <-time.After(timeout): - t.Fatal("test did not complete in time") - case <-ctx.Done(): - wg.Wait() - } - - lb.Lock() - defer lb.Unlock() - - assert.Equal(t, test.expectedNumRemovedServers, lb.numRemovedServers, "removed servers") - assert.Equal(t, test.expectedNumUpsertedServers, lb.numUpsertedServers, "upserted servers") - assert.InDelta(t, test.expectedGaugeValue, collectingMetrics.GaugeValue, delta, "ServerUp Gauge") - }) - } -} - -func TestNewRequest(t *testing.T) { - type expected struct { - err bool - value string - } - +func TestServiceHealthChecker_newRequest(t *testing.T) { testCases := []struct { - desc string - serverURL string - options Options - expected expected + desc string + targetURL string + config dynamic.ServerHealthCheck + expTarget string + expError bool + expHostname string + expHeader string + expMethod string }{ { desc: "no port override", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/test", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/test", - }, + expError: false, + expTarget: "http://backend1:80/test", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "port override", - serverURL: "http://backend2:80", - options: Options{ + targetURL: "http://backend2:80", + config: dynamic.ServerHealthCheck{ Path: "/test", Port: 8080, }, - expected: expected{ - err: false, - value: "http://backend2:8080/test", - }, + expError: false, + expTarget: "http://backend2:8080/test", + expHostname: "backend2:8080", + expMethod: http.MethodGet, }, { desc: "no port override with no port in server URL", - serverURL: "http://backend1", - options: Options{ + targetURL: "http://backend1", + config: dynamic.ServerHealthCheck{ Path: "/health", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1/health", - }, + expError: false, + expTarget: "http://backend1/health", + expHostname: "backend1", + expMethod: http.MethodGet, }, { desc: "port override with no port in server URL", - serverURL: "http://backend2", - options: Options{ + targetURL: "http://backend2", + config: dynamic.ServerHealthCheck{ Path: "/health", Port: 8080, }, - expected: expected{ - err: false, - value: "http://backend2:8080/health", - }, + expError: false, + expTarget: "http://backend2:8080/health", + expHostname: "backend2:8080", + expMethod: http.MethodGet, }, { desc: "scheme override", - 
serverURL: "https://backend1:80", - options: Options{ + targetURL: "https://backend1:80", + config: dynamic.ServerHealthCheck{ Scheme: "http", Path: "/test", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/test", - }, + expError: false, + expTarget: "http://backend1:80/test", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with param", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/health?powpow=do", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/health?powpow=do", - }, + expError: false, + expTarget: "http://backend1:80/health?powpow=do", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with params", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/health?powpow=do&do=powpow", Port: 0, }, - expected: expected{ - err: false, - value: "http://backend1:80/health?powpow=do&do=powpow", - }, + expError: false, + expTarget: "http://backend1:80/health?powpow=do&do=powpow", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, { desc: "path with invalid path", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: ":", Port: 0, }, - expected: expected{ - err: true, - value: "", - }, + expError: true, + expTarget: "", + expHostname: "backend1:80", + expMethod: http.MethodGet, }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - backend := NewBackendConfig(test.options, "backendName") - - u := testhelpers.MustParseURL(test.serverURL) - - req, err := backend.newRequest(u) - - if test.expected.err { - require.Error(t, err) - assert.Nil(t, nil) - } else { - require.NoError(t, err, "failed to create new backend request") - require.NotNil(t, req) - assert.Equal(t, test.expected.value, req.URL.String()) - } - }) - } -} - -func TestRequestOptions(t *testing.T) { - testCases := []struct { - desc string - serverURL string - options Options - expectedHostname string - expectedHeader string - expectedMethod string - }{ { desc: "override hostname", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Hostname: "myhost", Path: "/", }, - expectedHostname: "myhost", - expectedHeader: "", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "myhost", + expHeader: "", + expMethod: http.MethodGet, }, { desc: "not override hostname", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Hostname: "", Path: "/", }, - expectedHostname: "backend1:80", - expectedHeader: "", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expHeader: "", + expMethod: http.MethodGet, }, { desc: "custom header", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Headers: map[string]string{"Custom-Header": "foo"}, Hostname: "", Path: "/", }, - expectedHostname: "backend1:80", - expectedHeader: "foo", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expHeader: "foo", + expMethod: http.MethodGet, }, { desc: "custom header with hostname override", - serverURL: 
"http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Headers: map[string]string{"Custom-Header": "foo"}, Hostname: "myhost", Path: "/", }, - expectedHostname: "myhost", - expectedHeader: "foo", - expectedMethod: http.MethodGet, + expTarget: "http://backend1:80/", + expHostname: "myhost", + expHeader: "foo", + expMethod: http.MethodGet, }, { desc: "custom method", - serverURL: "http://backend1:80", - options: Options{ + targetURL: "http://backend1:80", + config: dynamic.ServerHealthCheck{ Path: "/", Method: http.MethodHead, }, - expectedHostname: "backend1:80", - expectedMethod: http.MethodHead, + expTarget: "http://backend1:80/", + expHostname: "backend1:80", + expMethod: http.MethodHead, }, } @@ -359,259 +193,215 @@ func TestRequestOptions(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - backend := NewBackendConfig(test.options, "backendName") - - u, err := url.Parse(test.serverURL) - require.NoError(t, err) + shc := ServiceHealthChecker{config: &test.config} - req, err := backend.newRequest(u) - require.NoError(t, err, "failed to create new backend request") + u := testhelpers.MustParseURL(test.targetURL) + req, err := shc.newRequest(context.Background(), u) - req = backend.setRequestOptions(req) + if test.expError { + require.Error(t, err) + assert.Nil(t, req) + } else { + require.NoError(t, err, "failed to create new request") + require.NotNil(t, req) - assert.Equal(t, "http://backend1:80/", req.URL.String()) - assert.Equal(t, test.expectedHostname, req.Host) - assert.Equal(t, test.expectedHeader, req.Header.Get("Custom-Header")) - assert.Equal(t, test.expectedMethod, req.Method) + assert.Equal(t, test.expTarget, req.URL.String()) + assert.Equal(t, test.expHeader, req.Header.Get("Custom-Header")) + assert.Equal(t, test.expHostname, req.Host) + assert.Equal(t, test.expMethod, req.Method) + } }) } } -func TestBalancers_Servers(t *testing.T) { - server1, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer1.UpsertServer(server1) - require.NoError(t, err) - - server2, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer2.UpsertServer(server2) - require.NoError(t, err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - want, err := url.Parse("http://foo.com") - require.NoError(t, err) - - assert.Len(t, balancers.Servers(), 1) - assert.Equal(t, want, balancers.Servers()[0]) -} - -func TestBalancers_UpsertServer(t *testing.T) { - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - want, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - err = balancers.UpsertServer(want) - require.NoError(t, err) - - assert.Len(t, balancer1.Servers(), 1) - assert.Equal(t, want, balancer1.Servers()[0]) - - assert.Len(t, balancer2.Servers(), 1) - assert.Equal(t, want, balancer2.Servers()[0]) -} - -func TestBalancers_RemoveServer(t *testing.T) { - server, err := url.Parse("http://foo.com") - require.NoError(t, err) - - balancer1, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer1.UpsertServer(server) - require.NoError(t, err) - - balancer2, err := roundrobin.New(nil) - require.NoError(t, err) - - err = balancer2.UpsertServer(server) - require.NoError(t, 
err) - - balancers := Balancers([]Balancer{balancer1, balancer2}) - - err = balancers.RemoveServer(server) - require.NoError(t, err) - - assert.Empty(t, balancer1.Servers()) - assert.Empty(t, balancer2.Servers()) -} - -type testLoadBalancer struct { - // RWMutex needed due to parallel test execution: Both the system-under-test - // and the test assertions reference the counters. - *sync.RWMutex - numRemovedServers int - numUpsertedServers int - servers []*url.URL - // options is just to make sure that LBStatusUpdater forwards options on Upsert to its BalancerHandler - options []roundrobin.ServerOption -} - -func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // noop -} - -func (lb *testLoadBalancer) RemoveServer(u *url.URL) error { - lb.Lock() - defer lb.Unlock() - lb.numRemovedServers++ - lb.removeServer(u) - return nil -} - -func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - lb.Lock() - defer lb.Unlock() - lb.numUpsertedServers++ - lb.servers = append(lb.servers, u) - lb.options = append(lb.options, options...) - return nil -} - -func (lb *testLoadBalancer) Servers() []*url.URL { - return lb.servers -} - -func (lb *testLoadBalancer) Options() []roundrobin.ServerOption { - return lb.options -} - -func (lb *testLoadBalancer) removeServer(u *url.URL) { - var i int - var serverURL *url.URL - found := false - for i, serverURL = range lb.servers { - if *serverURL == *u { - found = true - break - } - } - if !found { - return - } - - lb.servers = append(lb.servers[:i], lb.servers[i+1:]...) -} - -func newTestServer(done func(), healthSequence []int) *httptest.Server { - handler := &testHandler{ - done: done, - healthSequence: healthSequence, - } - return httptest.NewServer(handler) -} - -// ServeHTTP returns HTTP response codes following a status sequences. -// It calls the given 'done' function once all request health indicators have been depleted. 
-func (th *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if len(th.healthSequence) == 0 { - panic("received unexpected request") - } - - w.WriteHeader(th.healthSequence[0]) - - th.healthSequence = th.healthSequence[1:] - if len(th.healthSequence) == 0 { - th.done() - } -} - -func TestLBStatusUpdater(t *testing.T) { - lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} - svInfo := &runtime.ServiceInfo{} - lbsu := NewLBStatusUpdater(lb, svInfo, nil) - newServer, err := url.Parse("http://foo.com") - assert.NoError(t, err) - err = lbsu.UpsertServer(newServer, roundrobin.Weight(1)) - assert.NoError(t, err) - assert.Len(t, lbsu.Servers(), 1) - assert.Len(t, lbsu.BalancerHandler.(*testLoadBalancer).Options(), 1) - statuses := svInfo.GetAllStatus() - assert.Len(t, statuses, 1) - for k, v := range statuses { - assert.Equal(t, newServer.String(), k) - assert.Equal(t, serverUp, v) - break - } - err = lbsu.RemoveServer(newServer) - assert.NoError(t, err) - assert.Empty(t, lbsu.Servers()) - statuses = svInfo.GetAllStatus() - assert.Len(t, statuses, 1) - for k, v := range statuses { - assert.Equal(t, newServer.String(), k) - assert.Equal(t, serverDown, v) - break - } -} - -func TestNotFollowingRedirects(t *testing.T) { +func TestServiceHealthChecker_checkHealthHTTP_NotFollowingRedirects(t *testing.T) { redirectServerCalled := false redirectTestServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { redirectServerCalled = true })) defer redirectTestServer.Close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(dynamic.DefaultHealthCheckTimeout)) defer cancel() server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Add("location", redirectTestServer.URL) rw.WriteHeader(http.StatusSeeOther) - cancel() })) defer server.Close() - lb := &testLoadBalancer{ - RWMutex: &sync.RWMutex{}, - servers: []*url.URL{testhelpers.MustParseURL(server.URL)}, + config := &dynamic.ServerHealthCheck{ + Path: "/path", + FollowRedirects: Bool(false), + Interval: dynamic.DefaultHealthCheckInterval, + Timeout: dynamic.DefaultHealthCheckTimeout, } + healthChecker := NewServiceHealthChecker(ctx, nil, config, nil, nil, http.DefaultTransport, nil) - backend := NewBackendConfig(Options{ - Path: "/path", - Interval: healthCheckInterval, - Timeout: healthCheckTimeout, - LB: lb, - FollowRedirects: false, - }, "backendName") + err := healthChecker.checkHealthHTTP(ctx, testhelpers.MustParseURL(server.URL)) + require.NoError(t, err) - collectingMetrics := &testhelpers.CollectingGauge{} - check := HealthCheck{ - Backends: make(map[string]*BackendConfig), - metrics: metricsHealthcheck{serverUpGauge: collectingMetrics}, + assert.False(t, redirectServerCalled, "HTTP redirect must not be followed") +} + +func TestServiceHealthChecker_Launch(t *testing.T) { + testCases := []struct { + desc string + mode string + status int + server StartTestServer + expNumRemovedServers int + expNumUpsertedServers int + expGaugeValue float64 + targetStatus string + }{ + { + desc: "healthy server staying healthy", + server: newHTTPServer(http.StatusOK), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy, with custom code status check", + server: newHTTPServer(http.StatusNotFound), + status: http.StatusNotFound, + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + 
expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy (StatusNoContent)", + server: newHTTPServer(http.StatusNoContent), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server staying healthy (StatusPermanentRedirect)", + server: newHTTPServer(http.StatusPermanentRedirect), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server becoming sick", + server: newHTTPServer(http.StatusServiceUnavailable), + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy server becoming sick, with custom code status check", + server: newHTTPServer(http.StatusOK), + status: http.StatusServiceUnavailable, + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy server toggling to sick and back to healthy", + server: newHTTPServer(http.StatusServiceUnavailable, http.StatusOK), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy server toggling to healthy and go to sick", + server: newHTTPServer(http.StatusOK, http.StatusServiceUnavailable), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy grpc server staying healthy", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_SERVING), + expNumRemovedServers: 0, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, + { + desc: "healthy grpc server becoming sick", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_NOT_SERVING), + expNumRemovedServers: 1, + expNumUpsertedServers: 0, + expGaugeValue: 0, + targetStatus: runtime.StatusDown, + }, + { + desc: "healthy grpc server toggling to sick and back to healthy", + mode: "grpc", + server: newGRPCServer(healthpb.HealthCheckResponse_NOT_SERVING, healthpb.HealthCheckResponse_SERVING), + expNumRemovedServers: 1, + expNumUpsertedServers: 1, + expGaugeValue: 1, + targetStatus: runtime.StatusUp, + }, } - wg := sync.WaitGroup{} - wg.Add(1) + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + // The context is passed to the health check and + // canonically canceled by the test server once all expected requests have been received. 
+ ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + targetURL, timeout := test.server.Start(t, cancel) + + lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} + + config := &dynamic.ServerHealthCheck{ + Mode: test.mode, + Status: test.status, + Path: "/path", + Interval: ptypes.Duration(500 * time.Millisecond), + Timeout: ptypes.Duration(499 * time.Millisecond), + } + + gauge := &testhelpers.CollectingGauge{} + serviceInfo := &runtime.ServiceInfo{} + hc := NewServiceHealthChecker(ctx, &MetricsMock{gauge}, config, lb, serviceInfo, http.DefaultTransport, map[string]*url.URL{"test": targetURL}) + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + hc.Launch(ctx) + wg.Done() + }() + + select { + case <-time.After(timeout): + t.Fatal("test did not complete in time") + case <-ctx.Done(): + wg.Wait() + } - go func() { - check.execute(ctx, backend) - wg.Done() - }() + lb.Lock() + defer lb.Unlock() - timeout := time.Duration(int(healthCheckInterval) + 500) - select { - case <-time.After(timeout): - t.Fatal("test did not complete in time") - case <-ctx.Done(): - wg.Wait() + assert.Equal(t, test.expNumRemovedServers, lb.numRemovedServers, "removed servers") + assert.Equal(t, test.expNumUpsertedServers, lb.numUpsertedServers, "upserted servers") + assert.Equal(t, test.expGaugeValue, gauge.GaugeValue, "ServerUp Gauge") + assert.Equal(t, serviceInfo.GetAllStatus(), map[string]string{targetURL.String(): test.targetStatus}) + }) } +} - assert.False(t, redirectServerCalled, "HTTP redirect must not be followed") +func Bool(b bool) *bool { + return &b } diff --git a/pkg/healthcheck/mock_test.go b/pkg/healthcheck/mock_test.go new file mode 100644 index 0000000000..d2a9b52e11 --- /dev/null +++ b/pkg/healthcheck/mock_test.go @@ -0,0 +1,182 @@ +package healthcheck + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + gokitmetrics "github.com/go-kit/kit/metrics" + "github.com/stretchr/testify/assert" + "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/traefik/traefik/v2/pkg/testhelpers" + "google.golang.org/grpc" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type StartTestServer interface { + Start(t *testing.T, done func()) (*url.URL, time.Duration) +} + +type Status interface { + ~int | ~int32 +} + +type HealthSequence[T Status] struct { + sequenceMu sync.Mutex + sequence []T +} + +func (s *HealthSequence[T]) Pop() T { + s.sequenceMu.Lock() + defer s.sequenceMu.Unlock() + + stat := s.sequence[0] + + s.sequence = s.sequence[1:] + + return stat +} + +func (s *HealthSequence[T]) IsEmpty() bool { + s.sequenceMu.Lock() + defer s.sequenceMu.Unlock() + + return len(s.sequence) == 0 +} + +type GRPCServer struct { + status HealthSequence[healthpb.HealthCheckResponse_ServingStatus] + done func() +} + +func newGRPCServer(healthSequence ...healthpb.HealthCheckResponse_ServingStatus) *GRPCServer { + gRPCService := &GRPCServer{ + status: HealthSequence[healthpb.HealthCheckResponse_ServingStatus]{ + sequence: healthSequence, + }, + } + + return gRPCService +} + +func (s *GRPCServer) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + if s.status.IsEmpty() { + s.done() + return &healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVICE_UNKNOWN, + }, nil + } + stat := s.status.Pop() + + return &healthpb.HealthCheckResponse{ + Status: stat, + }, nil +} + +func (s *GRPCServer) Watch(_ *healthpb.HealthCheckRequest, server 
healthpb.Health_WatchServer) error { + if s.status.IsEmpty() { + s.done() + return server.Send(&healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVICE_UNKNOWN, + }) + } + stat := s.status.Pop() + + return server.Send(&healthpb.HealthCheckResponse{ + Status: stat, + }) +} + +func (s *GRPCServer) Start(t *testing.T, done func()) (*url.URL, time.Duration) { + t.Helper() + + listener, err := net.Listen("tcp4", "127.0.0.1:0") + assert.NoError(t, err) + t.Cleanup(func() { _ = listener.Close() }) + + server := grpc.NewServer() + t.Cleanup(server.Stop) + + s.done = done + + healthpb.RegisterHealthServer(server, s) + + go func() { + err := server.Serve(listener) + assert.NoError(t, err) + }() + + // Make test timeout dependent on number of expected requests, health check interval, and a safety margin. + return testhelpers.MustParseURL("http://" + listener.Addr().String()), time.Duration(len(s.status.sequence)*int(dynamic.DefaultHealthCheckInterval) + 500) +} + +type HTTPServer struct { + status HealthSequence[int] + done func() +} + +func newHTTPServer(healthSequence ...int) *HTTPServer { + handler := &HTTPServer{ + status: HealthSequence[int]{ + sequence: healthSequence, + }, + } + + return handler +} + +// ServeHTTP returns HTTP response codes following a status sequence. +// It calls the given 'done' function once all request health indicators have been depleted. +func (s *HTTPServer) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + if s.status.IsEmpty() { + s.done() + // This ensures that the health-checker will handle the context cancellation error before receiving the HTTP response. + time.Sleep(500 * time.Millisecond) + return + } + + stat := s.status.Pop() + + w.WriteHeader(stat) +} + +func (s *HTTPServer) Start(t *testing.T, done func()) (*url.URL, time.Duration) { + t.Helper() + + s.done = done + + ts := httptest.NewServer(s) + t.Cleanup(ts.Close) + + // Make test timeout dependent on number of expected requests, health check interval, and a safety margin. + return testhelpers.MustParseURL(ts.URL), time.Duration(len(s.status.sequence)*int(dynamic.DefaultHealthCheckInterval) + 500) +} + +type testLoadBalancer struct { + // RWMutex needed due to parallel test execution: Both the system-under-test + // and the test assertions reference the counters. + *sync.RWMutex + numRemovedServers int + numUpsertedServers int +} + +func (lb *testLoadBalancer) SetStatus(ctx context.Context, childName string, up bool) { + if up { + lb.numUpsertedServers++ + } else { + lb.numRemovedServers++ + } +} + +type MetricsMock struct { + Gauge gokitmetrics.Gauge +} + +func (m *MetricsMock) ServiceServerUpGauge() gokitmetrics.Gauge { + return m.Gauge +} diff --git a/pkg/middlewares/emptybackendhandler/empty_backend_handler.go b/pkg/middlewares/emptybackendhandler/empty_backend_handler.go deleted file mode 100644 index 3331bf3e36..0000000000 --- a/pkg/middlewares/emptybackendhandler/empty_backend_handler.go +++ /dev/null @@ -1,34 +0,0 @@ -package emptybackendhandler - -import ( - "net/http" - - "github.com/traefik/traefik/v2/pkg/healthcheck" -) - -// EmptyBackend is a middleware that checks whether the current Backend -// has at least one active Server in respect to the healthchecks and if this -// is not the case, it will stop the middleware chain and respond with 503. -type emptyBackend struct { - healthcheck.BalancerStatusHandler -} - -// New creates a new EmptyBackend middleware.
-func New(lb healthcheck.BalancerStatusHandler) http.Handler { - return &emptyBackend{BalancerStatusHandler: lb} -} - -// ServeHTTP responds with 503 when there is no active Server and otherwise -// invokes the next handler in the middleware chain. -func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if len(e.BalancerStatusHandler.Servers()) != 0 { - e.BalancerStatusHandler.ServeHTTP(rw, req) - return - } - - rw.WriteHeader(http.StatusServiceUnavailable) - if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } -} diff --git a/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go b/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go deleted file mode 100644 index 299fe9c316..0000000000 --- a/pkg/middlewares/emptybackendhandler/empty_backend_handler_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package emptybackendhandler - -import ( - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/traefik/traefik/v2/pkg/testhelpers" - "github.com/vulcand/oxy/v2/roundrobin" -) - -func TestEmptyBackendHandler(t *testing.T) { - testCases := []struct { - amountServer int - expectedStatusCode int - }{ - { - amountServer: 0, - expectedStatusCode: http.StatusServiceUnavailable, - }, - { - amountServer: 1, - expectedStatusCode: http.StatusOK, - }, - } - - for _, test := range testCases { - t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) { - t.Parallel() - - handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer}) - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - - handler.ServeHTTP(recorder, req) - - assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode) - }) - } -} - -type healthCheckLoadBalancer struct { - amountServer int -} - -func (lb *healthCheckLoadBalancer) RegisterStatusUpdater(fn func(up bool)) error { - return nil -} - -func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func (lb *healthCheckLoadBalancer) Servers() []*url.URL { - servers := make([]*url.URL, lb.amountServer) - for range lb.amountServer { - servers = append(servers, testhelpers.MustParseURL("http://localhost")) - } - return servers -} - -func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error { - return nil -} - -func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - return nil -} - -func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) { - return 0, false -} - -func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) { - return nil, nil -} - -func (lb *healthCheckLoadBalancer) Next() http.Handler { - return nil -} diff --git a/pkg/provider/kubernetes/crd/kubernetes_test.go b/pkg/provider/kubernetes/crd/kubernetes_test.go index 6724e5c631..1037b6c678 100644 --- a/pkg/provider/kubernetes/crd/kubernetes_test.go +++ b/pkg/provider/kubernetes/crd/kubernetes_test.go @@ -3573,7 +3573,7 @@ func TestLoadIngressRoutes(t *testing.T) { }, }, PassHostHeader: Bool(false), - ResponseForwarding: &dynamic.ResponseForwarding{FlushInterval: "10s"}, + ResponseForwarding: &dynamic.ResponseForwarding{FlushInterval: ptypes.Duration(10 * time.Second)}, }, }, }, diff --git a/pkg/provider/kv/kv_test.go b/pkg/provider/kv/kv_test.go index 
535a48bab6..b0b9653637 100644 --- a/pkg/provider/kv/kv_test.go +++ b/pkg/provider/kv/kv_test.go @@ -42,14 +42,14 @@ func Test_buildConfiguration(t *testing.T) { "traefik/http/routers/Router1/service": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/path": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/port": "42", - "traefik/http/services/Service01/loadBalancer/healthCheck/interval": "foobar", - "traefik/http/services/Service01/loadBalancer/healthCheck/timeout": "foobar", + "traefik/http/services/Service01/loadBalancer/healthCheck/interval": "1s", + "traefik/http/services/Service01/loadBalancer/healthCheck/timeout": "1s", "traefik/http/services/Service01/loadBalancer/healthCheck/hostname": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/headers/name0": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/headers/name1": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/scheme": "foobar", "traefik/http/services/Service01/loadBalancer/healthCheck/followredirects": "true", - "traefik/http/services/Service01/loadBalancer/responseForwarding/flushInterval": "foobar", + "traefik/http/services/Service01/loadBalancer/responseForwarding/flushInterval": "100ms", "traefik/http/services/Service01/loadBalancer/passHostHeader": "true", "traefik/http/services/Service01/loadBalancer/sticky/cookie/name": "foobar", "traefik/http/services/Service01/loadBalancer/sticky/cookie/secure": "true", @@ -644,8 +644,8 @@ func Test_buildConfiguration(t *testing.T) { Scheme: "foobar", Path: "foobar", Port: 42, - Interval: "foobar", - Timeout: "foobar", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foobar", FollowRedirects: func(v bool) *bool { return &v }(true), Headers: map[string]string{ @@ -655,7 +655,7 @@ func Test_buildConfiguration(t *testing.T) { }, PassHostHeader: func(v bool) *bool { return &v }(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foobar", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, }, }, diff --git a/pkg/redactor/redactor_config_test.go b/pkg/redactor/redactor_config_test.go index 368f2d9564..019eb5247a 100644 --- a/pkg/redactor/redactor_config_test.go +++ b/pkg/redactor/redactor_config_test.go @@ -82,8 +82,8 @@ func init() { Scheme: "foo", Path: "foo", Port: 42, - Interval: "foo", - Timeout: "foo", + Interval: ptypes.Duration(time.Second), + Timeout: ptypes.Duration(time.Second), Hostname: "foo", FollowRedirects: boolPtr(true), Headers: map[string]string{ @@ -92,7 +92,7 @@ func init() { }, PassHostHeader: boolPtr(true), ResponseForwarding: &dynamic.ResponseForwarding{ - FlushInterval: "foo", + FlushInterval: ptypes.Duration(100 * time.Millisecond), }, ServersTransport: "foo", Servers: []dynamic.Server{ diff --git a/pkg/redactor/testdata/anonymized-dynamic-config.json b/pkg/redactor/testdata/anonymized-dynamic-config.json index 8339f58c3c..6b967f91f3 100644 --- a/pkg/redactor/testdata/anonymized-dynamic-config.json +++ b/pkg/redactor/testdata/anonymized-dynamic-config.json @@ -75,8 +75,8 @@ "scheme": "foo", "path": "foo", "port": 42, - "interval": "foo", - "timeout": "foo", + "interval": "1s", + "timeout": "1s", "hostname": "xxxx", "followRedirects": true, "headers": { @@ -85,7 +85,7 @@ }, "passHostHeader": true, "responseForwarding": { - "flushInterval": "foo" + "flushInterval": "100ms" }, "serversTransport": "foo" } @@ -475,4 +475,4 @@ } } } -} \ No newline at end of file +} diff --git 
a/pkg/redactor/testdata/secured-dynamic-config.json b/pkg/redactor/testdata/secured-dynamic-config.json index 352421884d..4447b1698e 100644 --- a/pkg/redactor/testdata/secured-dynamic-config.json +++ b/pkg/redactor/testdata/secured-dynamic-config.json @@ -75,8 +75,8 @@ "scheme": "foo", "path": "foo", "port": 42, - "interval": "foo", - "timeout": "foo", + "interval": "1s", + "timeout": "1s", "hostname": "foo", "followRedirects": true, "headers": { @@ -85,7 +85,7 @@ }, "passHostHeader": true, "responseForwarding": { - "flushInterval": "foo" + "flushInterval": "100ms" }, "serversTransport": "foo" } @@ -483,4 +483,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 0b88b0fafa..eabf5cbc84 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -31,7 +31,7 @@ type middlewareBuilder interface { type serviceManager interface { BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) - LaunchHealthCheck() + LaunchHealthCheck(ctx context.Context) } // Manager A route/router manager. diff --git a/pkg/server/router/router_test.go b/pkg/server/router/router_test.go index d5fba5af1a..2d47f64c92 100644 --- a/pkg/server/router/router_test.go +++ b/pkg/server/router/router_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/containous/alice" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/metrics" @@ -482,7 +484,7 @@ func TestRuntimeConfiguration(t *testing.T) { }, }, HealthCheck: &dynamic.ServerHealthCheck{ - Interval: "500ms", + Interval: ptypes.Duration(500 * time.Millisecond), Path: "/health", }, }, diff --git a/pkg/server/routerfactory.go b/pkg/server/routerfactory.go index 6b7b80ff4b..b4cb5109f2 100644 --- a/pkg/server/routerfactory.go +++ b/pkg/server/routerfactory.go @@ -77,7 +77,7 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false) handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true) - serviceManager.LaunchHealthCheck() + serviceManager.LaunchHealthCheck(ctx) // TCP svcTCPManager := tcp.NewManager(rtConf) diff --git a/pkg/server/service/internalhandler.go b/pkg/server/service/internalhandler.go index 4bd6fd2585..e5c3ab26d7 100644 --- a/pkg/server/service/internalhandler.go +++ b/pkg/server/service/internalhandler.go @@ -10,7 +10,7 @@ import ( type serviceManager interface { BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) - LaunchHealthCheck() + LaunchHealthCheck(ctx context.Context) } // InternalHandlers is the internal HTTP handlers builder. diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index ecca04b624..720945ffa3 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -42,7 +42,7 @@ type Balancer struct { curDeadline float64 // status is a record of which child services of the Balancer are healthy, keyed // by name of child service. A service is initially added to the map when it is - // created via AddService, and it is later removed or added to the map as needed, + // created via Add, and it is later removed or added to the map as needed, // through the SetStatus method. 
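+ // Only currently healthy children have an entry in the map.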
status map[string]struct{} // updaters is the list of hooks that are run (to update the Balancer @@ -51,11 +51,11 @@ type Balancer struct { } // New creates a new load balancer. -func New(sticky *dynamic.Sticky, hc *dynamic.HealthCheck) *Balancer { +func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { balancer := &Balancer{ status: make(map[string]struct{}), handlerMap: make(map[string]*namedHandler), - wantsHealthCheck: hc != nil, + wantsHealthCheck: wantHealthCheck, } if sticky != nil && sticky.Cookie != nil { balancer.stickyCookie = &stickyCookie{ @@ -155,10 +155,7 @@ func (b *Balancer) nextServer() (*namedHandler, error) { b.handlersMu.Lock() defer b.handlersMu.Unlock() - if len(b.handlers) == 0 { - return nil, errors.New("no servers in the pool") - } - if len(b.status) == 0 { + if len(b.handlers) == 0 || len(b.status) == 0 { return nil, errNoAvailableServer } @@ -224,9 +221,9 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { server.ServeHTTP(w, req) } -// AddService adds a handler. +// Add adds a handler. // A handler with a non-positive weight is ignored. -func (b *Balancer) AddService(name string, handler http.Handler, weight *int) { +func (b *Balancer) Add(name string, handler http.Handler, weight *int) { w := 1 if weight != nil { w = *weight diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index 32068504e3..19f2cf38ef 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -10,31 +10,15 @@ import ( "github.com/traefik/traefik/v2/pkg/config/dynamic" ) -func Int(v int) *int { return &v } - -type responseRecorder struct { - *httptest.ResponseRecorder - save map[string]int - sequence []string - status []int -} - -func (r *responseRecorder) WriteHeader(statusCode int) { - r.save[r.Header().Get("server")]++ - r.sequence = append(r.sequence, r.Header().Get("server")) - r.status = append(r.status, statusCode) - r.ResponseRecorder.WriteHeader(statusCode) -} - func TestBalancer(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(3)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) @@ -49,23 +33,23 @@ func TestBalancer(t *testing.T) { } func TestBalancerNoService(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) recorder := httptest.NewRecorder() balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) } func TestBalancerOneServerZeroWeight(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw 
http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} for range 3 { @@ -80,13 +64,13 @@ type key string const serviceName key = "serviceName" func TestBalancerNoServiceUp(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) @@ -100,14 +84,14 @@ func TestBalancerNoServiceUp(t *testing.T) { } func TestBalancerOneServerDown(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) }), Int(1)) balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) @@ -121,14 +105,14 @@ func TestBalancerOneServerDown(t *testing.T) { } func TestBalancerDownThenUp(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) @@ -150,35 +134,35 @@ func TestBalancerDownThenUp(t *testing.T) { } func TestBalancerPropagate(t *testing.T) { - balancer1 := New(nil, &dynamic.HealthCheck{}) + balancer1 := New(nil, true) - balancer1.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer1.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer2 := New(nil, &dynamic.HealthCheck{}) - balancer2.AddService("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer2 := New(nil, true) + balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "third") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer2.AddService("fourth", 
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "fourth") rw.WriteHeader(http.StatusOK) }), Int(1)) - topBalancer := New(nil, &dynamic.HealthCheck{}) - topBalancer.AddService("balancer1", balancer1, Int(1)) + topBalancer := New(nil, true) + topBalancer.Add("balancer1", balancer1, Int(1)) _ = balancer1.RegisterStatusUpdater(func(up bool) { topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer1", up) // TODO(mpl): if test gets flaky, add channel or something here to signal that // propagation is done, and wait on it before sending request. }) - topBalancer.AddService("balancer2", balancer2, Int(1)) + topBalancer.Add("balancer2", balancer2, Int(1)) _ = balancer2.RegisterStatusUpdater(func(up bool) { topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer2", up) }) @@ -223,28 +207,28 @@ func TestBalancerPropagate(t *testing.T) { } func TestBalancerAllServersZeroWeight(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - balancer.AddService("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) + balancer.Add("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) recorder := httptest.NewRecorder() balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) } func TestSticky(t *testing.T) { balancer := New(&dynamic.Sticky{ Cookie: &dynamic.Cookie{Name: "test"}, - }, nil) + }, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(2)) @@ -270,14 +254,14 @@ func TestSticky(t *testing.T) { func TestSticky_FallBack(t *testing.T) { balancer := New(&dynamic.Sticky{ Cookie: &dynamic.Cookie{Name: "test"}, - }, nil) + }, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "first") rw.WriteHeader(http.StatusOK) }), Int(1)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "second") rw.WriteHeader(http.StatusOK) }), Int(2)) @@ -299,21 +283,21 @@ func TestSticky_FallBack(t *testing.T) { // TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start, // and that it does not "over-favor" the high-weighted ones with a biased start-up regime. 
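+// With weights 11 and 3, the expected sequence below already alternates between both servers instead of opening with a long run of the heavier one.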
func TestBalancerBias(t *testing.T) { - balancer := New(nil, nil) + balancer := New(nil, false) - balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "A") rw.WriteHeader(http.StatusOK) }), Int(11)) - balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "B") rw.WriteHeader(http.StatusOK) }), Int(3)) recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 14 { + for i := 0; i < 14; i++ { balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) } @@ -321,3 +305,19 @@ func TestBalancerBias(t *testing.T) { assert.Equal(t, wantSequence, recorder.sequence) } + +func Int(v int) *int { return &v } + +type responseRecorder struct { + *httptest.ResponseRecorder + save map[string]int + sequence []string + status []int +} + +func (r *responseRecorder) WriteHeader(statusCode int) { + r.save[r.Header().Get("server")]++ + r.sequence = append(r.sequence, r.Header().Get("server")) + r.status = append(r.status, statusCode) + r.ResponseRecorder.WriteHeader(statusCode) +} diff --git a/pkg/server/service/proxy.go b/pkg/server/service/proxy.go index 257c4ce0d4..267b97b707 100644 --- a/pkg/server/service/proxy.go +++ b/pkg/server/service/proxy.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "fmt" "io" "net" "net/http" @@ -12,8 +11,6 @@ import ( "strings" "time" - ptypes "github.com/traefik/paerser/types" - "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" "golang.org/x/net/http/httpguts" ) @@ -24,100 +21,104 @@ const StatusClientClosedRequest = 499 // StatusClientClosedRequestText non-standard HTTP status for client disconnection. 
const StatusClientClosedRequestText = "Client Closed Request" -func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) (http.Handler, error) { - var flushInterval ptypes.Duration - if responseForwarding != nil { - err := flushInterval.Set(responseForwarding.FlushInterval) - if err != nil { - return nil, fmt.Errorf("error creating flush interval: %w", err) - } - } - if flushInterval == 0 { - flushInterval = ptypes.Duration(100 * time.Millisecond) +func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler { + return &httputil.ReverseProxy{ + Director: directorBuilder(target, passHostHeader), + Transport: roundTripper, + FlushInterval: flushInterval, + BufferPool: bufferPool, + ErrorHandler: errorHandler, } +} + +func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) { + return func(outReq *http.Request) { + outReq.URL.Scheme = target.Scheme + outReq.URL.Host = target.Host - proxy := &httputil.ReverseProxy{ - Director: func(outReq *http.Request) { - u := outReq.URL - if outReq.RequestURI != "" { - parsedURL, err := url.ParseRequestURI(outReq.RequestURI) - if err == nil { - u = parsedURL - } + u := outReq.URL + if outReq.RequestURI != "" { + parsedURL, err := url.ParseRequestURI(outReq.RequestURI) + if err == nil { + u = parsedURL } + } - outReq.URL.Path = u.Path - outReq.URL.RawPath = u.RawPath - // If a plugin/middleware adds semicolons in query params, they should be urlEncoded. - outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&") - outReq.RequestURI = "" // Outgoing request should not have RequestURI + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + // If a plugin/middleware adds semicolons in query params, they should be urlEncoded. + outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&") + outReq.RequestURI = "" // Outgoing request should not have RequestURI - outReq.Proto = "HTTP/1.1" - outReq.ProtoMajor = 1 - outReq.ProtoMinor = 1 + outReq.Proto = "HTTP/1.1" + outReq.ProtoMajor = 1 + outReq.ProtoMinor = 1 - // Do not pass client Host header unless optsetter PassHostHeader is set. - if passHostHeader != nil && !*passHostHeader { - outReq.Host = outReq.URL.Host - } - - // Even if the websocket RFC says that headers should be case-insensitive, - // some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept, - // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. 
- // https://tools.ietf.org/html/rfc6455#page-20 - if isWebSocketUpgrade(outReq) { - outReq.Header["Sec-WebSocket-Key"] = outReq.Header["Sec-Websocket-Key"] - outReq.Header["Sec-WebSocket-Extensions"] = outReq.Header["Sec-Websocket-Extensions"] - outReq.Header["Sec-WebSocket-Accept"] = outReq.Header["Sec-Websocket-Accept"] - outReq.Header["Sec-WebSocket-Protocol"] = outReq.Header["Sec-Websocket-Protocol"] - outReq.Header["Sec-WebSocket-Version"] = outReq.Header["Sec-Websocket-Version"] - delete(outReq.Header, "Sec-Websocket-Key") - delete(outReq.Header, "Sec-Websocket-Extensions") - delete(outReq.Header, "Sec-Websocket-Accept") - delete(outReq.Header, "Sec-Websocket-Protocol") - delete(outReq.Header, "Sec-Websocket-Version") - } - }, - Transport: roundTripper, - FlushInterval: time.Duration(flushInterval), - BufferPool: bufferPool, - ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) { - statusCode := http.StatusInternalServerError + // Do not pass client Host header unless optsetter PassHostHeader is set. + if !passHostHeader { + outReq.Host = outReq.URL.Host + } - switch { - case errors.Is(err, io.EOF): - statusCode = http.StatusBadGateway - case errors.Is(err, context.Canceled): - statusCode = StatusClientClosedRequest - default: - var netErr net.Error - if errors.As(err, &netErr) { - if netErr.Timeout() { - statusCode = http.StatusGatewayTimeout - } else { - statusCode = http.StatusBadGateway - } - } - } + cleanWebSocketHeaders(outReq) + } +} - log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) - w.WriteHeader(statusCode) - _, werr := w.Write([]byte(statusText(statusCode))) - if werr != nil { - log.Debugf("Error while writing status code", werr) - } - }, +// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive, +// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept, +// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. 
+// https://tools.ietf.org/html/rfc6455#page-20 +func cleanWebSocketHeaders(req *http.Request) { + if !isWebSocketUpgrade(req) { + return } - return proxy, nil + req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"] + delete(req.Header, "Sec-Websocket-Key") + + req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"] + delete(req.Header, "Sec-Websocket-Extensions") + + req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"] + delete(req.Header, "Sec-Websocket-Accept") + + req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"] + delete(req.Header, "Sec-Websocket-Protocol") + + req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"] + delete(req.Header, "Sec-Websocket-Version") } func isWebSocketUpgrade(req *http.Request) bool { - if !httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { - return false + return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") && + strings.EqualFold(req.Header.Get("Upgrade"), "websocket") +} + +func errorHandler(w http.ResponseWriter, req *http.Request, err error) { + statusCode := http.StatusInternalServerError + + switch { + case errors.Is(err, io.EOF): + statusCode = http.StatusBadGateway + case errors.Is(err, context.Canceled): + statusCode = StatusClientClosedRequest + default: + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + statusCode = http.StatusGatewayTimeout + } else { + statusCode = http.StatusBadGateway + } + } } - return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") + logger := log.FromContext(req.Context()) + logger.WithError(err).Debugf("%d %s", statusCode, statusText(statusCode)) + + w.WriteHeader(statusCode) + if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil { + logger.WithError(werr).Debug("Error while writing status code") + } } func statusText(statusCode int) string { diff --git a/pkg/server/service/proxy_test.go b/pkg/server/service/proxy_test.go index 567ba61057..afaff38809 100644 --- a/pkg/server/service/proxy_test.go +++ b/pkg/server/service/proxy_test.go @@ -28,7 +28,7 @@ func BenchmarkProxy(b *testing.B) { req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) pool := newBufferPool() - handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool) + handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool) b.ReportAllocs() for range b.N { diff --git a/pkg/server/service/proxy_websocket_test.go b/pkg/server/service/proxy_websocket_test.go index 108133c37e..1d2293aeae 100644 --- a/pkg/server/service/proxy_websocket_test.go +++ b/pkg/server/service/proxy_websocket_test.go @@ -21,9 +21,6 @@ import ( func Bool(v bool) *bool { return &v } func TestWebSocketTCPClose(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -42,7 +39,7 @@ func TestWebSocketTCPClose(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxyAddr := proxy.Listener.Addr().String() _, conn, err := newWebsocketRequest( @@ -61,10 +58,6 @@ func TestWebSocketTCPClose(t *testing.T) { } func TestWebSocketPingPong(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - - 
require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{ HandshakeTimeout: 10 * time.Second, CheckOrigin: func(*http.Request) bool { @@ -86,17 +79,10 @@ func TestWebSocketPingPong(t *testing.T) { _, _, _ = ws.ReadMessage() }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() - + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} @@ -127,9 +113,6 @@ func TestWebSocketPingPong(t *testing.T) { } func TestWebSocketEcho(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) @@ -145,17 +128,10 @@ func TestWebSocketEcho(t *testing.T) { require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() - + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} @@ -193,10 +169,6 @@ func TestWebSocketPassHost(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil) - - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { req := conn.Request() @@ -208,7 +180,7 @@ func TestWebSocketPassHost(t *testing.T) { } msg := make([]byte, 4) - _, err = conn.Read(msg) + _, err := conn.Read(msg) require.NoError(t, err) fmt.Println(string(msg)) @@ -219,16 +191,10 @@ func TestWebSocketPassHost(t *testing.T) { require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) - f.ServeHTTP(w, req) - })) - defer proxy.Close() + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) serverAddr := proxy.Listener.Addr().String() @@ -252,9 +218,6 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} @@ -277,7 +240,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -293,9 +256,6 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { } func TestWebSocketRequestWithOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - 
require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) @@ -316,11 +276,11 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() - _, err = newWebsocketRequest( + _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), @@ -339,9 +299,6 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { } func TestWebSocketRequestWithQueryParams(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) @@ -363,7 +320,7 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -379,18 +336,14 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) { } func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { conn.Close() })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() + f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = parseURI(t, srv.URL) w.Header().Set("HEADER-KEY", "HEADER-VALUE") @@ -403,6 +356,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", err, resp) defer conn.Close() @@ -411,9 +365,6 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { } func TestWebSocketRequestWithEncodedChar(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) @@ -435,7 +386,7 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) { })) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -451,18 +402,14 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) { } func TestWebSocketUpgradeFailed(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusBadRequest) }) - srv := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() + f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path @@ -501,9 +448,6 @@ func TestWebSocketUpgradeFailed(t *testing.T) { } func TestForwardsWebsocketTraffic(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { _, err := conn.Write([]byte("ok")) @@ -512,12 +456,10 @@ func TestForwardsWebsocketTraffic(t *testing.T) { err = conn.Close() require.NoError(t, err) })) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - })) + srv := httptest.NewServer(mux) defer srv.Close() - proxy := createProxyWithForwarder(t, f, srv.URL) + proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() @@ -557,15 +499,12 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() - forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) - require.NoError(t, err) - - proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL) + proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxyWithoutTLSConfig.Close() proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() - _, err = newWebsocketRequest( + _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), @@ -576,10 +515,8 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil) - require.NoError(t, err) - proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL) + proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfig.Listener.Addr().String() @@ -597,10 +534,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { defaultTransport := http.DefaultTransport.(*http.Transport).Clone() defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, defaultTransport, nil) - require.NoError(t, err) - - proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL) + proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String() @@ -705,15 +639,19 @@ func parseURI(t *testing.T, uri string) *url.URL { return out } -func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server { +func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server { t.Helper() - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + u := parseURI(t, uri) + proxy := buildSingleHostProxy(u, true, 0, transport, nil) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, 
req *http.Request) { path := req.URL.Path // keep the original path // Set new backend URL - req.URL = parseURI(t, url) + req.URL = u req.URL.Path = path proxy.ServeHTTP(w, req) })) + t.Cleanup(srv.Close) + return srv } diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index cc43e76a2a..d611325a18 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -4,36 +4,28 @@ import ( "context" "errors" "fmt" + "hash/fnv" "math/rand" "net/http" "net/http/httputil" "net/url" "reflect" + "strings" "time" - "github.com/containous/alice" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/healthcheck" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/metrics" "github.com/traefik/traefik/v2/pkg/middlewares/accesslog" - "github.com/traefik/traefik/v2/pkg/middlewares/emptybackendhandler" metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics" - "github.com/traefik/traefik/v2/pkg/middlewares/pipelining" "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/cookie" "github.com/traefik/traefik/v2/pkg/server/provider" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/failover" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/mirror" "github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/wrr" - "github.com/vulcand/oxy/v2/roundrobin" - "github.com/vulcand/oxy/v2/roundrobin/stickycookie" -) - -const ( - defaultHealthCheckInterval = 30 * time.Second - defaultHealthCheckTimeout = 5 * time.Second ) const defaultMaxBodySize int64 = -1 @@ -43,6 +35,19 @@ type RoundTripperGetter interface { Get(name string) (http.RoundTripper, error) } +// Manager The service manager. +type Manager struct { + routinePool *safe.Pool + metricsRegistry metrics.Registry + bufferPool httputil.BufferPool + roundTripperManager RoundTripperGetter + + services map[string]http.Handler + configs map[string]*runtime.ServiceInfo + healthCheckers map[string]*healthcheck.ServiceHealthChecker + rand *rand.Rand // For the initial shuffling of load-balancers. +} + // NewManager creates a new Manager. func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics.Registry, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager { return &Manager{ @@ -50,27 +55,13 @@ func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics metricsRegistry: metricsRegistry, bufferPool: newBufferPool(), roundTripperManager: roundTripperManager, - balancers: make(map[string]healthcheck.Balancers), + services: make(map[string]http.Handler), configs: configs, + healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker), rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } -// Manager The service manager. -type Manager struct { - routinePool *safe.Pool - metricsRegistry metrics.Registry - bufferPool httputil.BufferPool - roundTripperManager RoundTripperGetter - // balancers is the map of all Balancers, keyed by service name. - // There is one Balancer per service handler, and there is one service handler per reference to a service - // (e.g. if 2 routers refer to the same service name, 2 service handlers are created), - // which is why there is not just one Balancer per service name. - balancers map[string]healthcheck.Balancers - configs map[string]*runtime.ServiceInfo - rand *rand.Rand // For the initial shuffling of load-balancers. 
-} - // BuildHTTP Creates a http.Handler for a service configuration. func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) { ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName)) @@ -78,11 +69,20 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H serviceName = provider.GetQualifiedName(ctx, serviceName) ctx = provider.AddInContext(ctx, serviceName) + handler, ok := m.services[serviceName] + if ok { + return handler, nil + } + conf, ok := m.configs[serviceName] if !ok { return nil, fmt.Errorf("the service %q does not exist", serviceName) } + if conf.Status == runtime.StatusDisabled { + return nil, errors.New(strings.Join(conf.Err, ", ")) + } + value := reflect.ValueOf(*conf.Service) var count int for i := range value.NumField() { @@ -101,7 +101,7 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H switch { case conf.LoadBalancer != nil: var err error - lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer) + lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf) if err != nil { conf.AddError(err, true) return nil, err @@ -133,6 +133,8 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H return nil, sErr } + m.services[serviceName] = lb + return lb, nil } @@ -214,14 +216,14 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName) } - balancer := wrr.New(config.Sticky, config.HealthCheck) + balancer := wrr.New(config.Sticky, config.HealthCheck != nil) for _, service := range shuffle(config.Services, m.rand) { serviceHandler, err := m.BuildHTTP(ctx, service.Name) if err != nil { return nil, err } - balancer.AddService(service.Name, serviceHandler, service.Weight) + balancer.Add(service.Name, serviceHandler, service.Weight) if config.HealthCheck == nil { continue @@ -245,201 +247,90 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, return balancer, nil } -func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer) (http.Handler, error) { - if service.PassHostHeader == nil { - defaultPassHostHeader := true - service.PassHostHeader = &defaultPassHostHeader - } +func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, info *runtime.ServiceInfo) (http.Handler, error) { + service := info.LoadBalancer - if len(service.ServersTransport) > 0 { - service.ServersTransport = provider.GetQualifiedName(ctx, service.ServersTransport) - } + logger := log.FromContext(ctx) + logger.Debug("Creating load-balancer") - roundTripper, err := m.roundTripperManager.Get(service.ServersTransport) - if err != nil { - return nil, err + // TODO: should we keep this config value as Go is now handling stream response correctly? 
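+ // Fall back to the default flush interval (100ms) unless responseForwarding overrides it.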
+	flushInterval := dynamic.DefaultFlushInterval
+	if service.ResponseForwarding != nil {
+		flushInterval = service.ResponseForwarding.FlushInterval
 	}
 
-	fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, roundTripper, m.bufferPool)
-	if err != nil {
-		return nil, err
+	if len(service.ServersTransport) > 0 {
+		service.ServersTransport = provider.GetQualifiedName(ctx, service.ServersTransport)
 	}
 
-	alHandler := func(next http.Handler) (http.Handler, error) {
-		return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil
-	}
-
-	chain := alice.New()
-
-	if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() {
-		chain = chain.Append(metricsMiddle.WrapServiceHandler(ctx, m.metricsRegistry, serviceName))
+	if service.Sticky != nil && service.Sticky.Cookie != nil {
+		service.Sticky.Cookie.Name = cookie.GetName(service.Sticky.Cookie.Name, serviceName)
 	}
 
-	handler, err := chain.Append(alHandler).Then(pipelining.New(ctx, fwd, "pipelining"))
-	if err != nil {
-		return nil, err
+	// We make sure that the PassHostHeader value is defined to avoid panics.
+	passHostHeader := dynamic.DefaultPassHostHeader
+	if service.PassHostHeader != nil {
+		passHostHeader = *service.PassHostHeader
 	}
 
-	balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler)
+	roundTripper, err := m.roundTripperManager.Get(service.ServersTransport)
 	if err != nil {
 		return nil, err
 	}
 
-	// TODO rename and checks
-	m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
+	lb := wrr.New(service.Sticky, service.HealthCheck != nil)
+	healthCheckTargets := make(map[string]*url.URL)
 
-	// Empty (backend with no servers)
-	return emptybackendhandler.New(balancer), nil
-}
-
-// LaunchHealthCheck launches the health checks.
-func (m *Manager) LaunchHealthCheck() {
-	backendConfigs := make(map[string]*healthcheck.BackendConfig)
+	for _, server := range shuffle(service.Servers, m.rand) {
+		hasher := fnv.New64a()
+		_, _ = hasher.Write([]byte(server.URL)) // this will never return an error.
 
-	for serviceName, balancers := range m.balancers {
-		ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
+		proxyName := fmt.Sprintf("%x", hasher.Sum(nil))
 
-		service := m.configs[serviceName].LoadBalancer
-
-		// Health Check
-		hcOpts := buildHealthCheckOptions(ctx, balancers, serviceName, service.HealthCheck)
-		if hcOpts == nil {
-			continue
-		}
-		hcOpts.Transport, _ = m.roundTripperManager.Get(service.ServersTransport)
-		log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
-
-		backendConfigs[serviceName] = healthcheck.NewBackendConfig(*hcOpts, serviceName)
-	}
-
-	healthcheck.GetHealthCheck(m.metricsRegistry).SetBackendsConfiguration(context.Background(), backendConfigs)
-}
-
-func buildHealthCheckOptions(ctx context.Context, lb healthcheck.Balancer, backend string, hc *dynamic.ServerHealthCheck) *healthcheck.Options {
-	if hc == nil {
-		return nil
-	}
-
-	logger := log.FromContext(ctx)
-
-	if hc.Path == "" {
-		logger.Errorf("Ignoring heath check configuration for '%s': no path provided", backend)
-		return nil
-	}
-
-	interval := defaultHealthCheckInterval
-	if hc.Interval != "" {
-		intervalOverride, err := time.ParseDuration(hc.Interval)
-		switch {
-		case err != nil:
-			logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
-		case intervalOverride <= 0:
-			logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
-		default:
-			interval = intervalOverride
-		}
-	}
-
-	timeout := defaultHealthCheckTimeout
-	if hc.Timeout != "" {
-		timeoutOverride, err := time.ParseDuration(hc.Timeout)
-		switch {
-		case err != nil:
-			logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
-		case timeoutOverride <= 0:
-			logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
-		default:
-			timeout = timeoutOverride
+		target, err := url.Parse(server.URL)
+		if err != nil {
+			return nil, fmt.Errorf("error parsing server URL %s: %w", server.URL, err)
 		}
-	}
-
-	followRedirects := true
-	if hc.FollowRedirects != nil {
-		followRedirects = *hc.FollowRedirects
-	}
-
-	return &healthcheck.Options{
-		Scheme:          hc.Scheme,
-		Path:            hc.Path,
-		Method:          hc.Method,
-		Port:            hc.Port,
-		Interval:        interval,
-		Timeout:         timeout,
-		LB:              lb,
-		Hostname:        hc.Hostname,
-		Headers:         hc.Headers,
-		FollowRedirects: followRedirects,
-	}
-}
-
-func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer, fwd http.Handler) (healthcheck.BalancerStatusHandler, error) {
-	logger := log.FromContext(ctx)
-	logger.Debug("Creating load-balancer")
-
-	var options []roundrobin.LBOption
+		logger.WithField(log.ServerName, proxyName).WithField("target", target).Info("Creating server")
 
-	var cookieName string
-	if service.Sticky != nil && service.Sticky.Cookie != nil {
-		cookieName = cookie.GetName(service.Sticky.Cookie.Name, serviceName)
+		proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool)
 
-		opts := roundrobin.CookieOptions{
-			HTTPOnly: service.Sticky.Cookie.HTTPOnly,
-			Secure:   service.Sticky.Cookie.Secure,
-			SameSite: convertSameSite(service.Sticky.Cookie.SameSite),
-		}
+		proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceURL, target.String(), nil)
+		proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceAddr, target.Host, nil)
+		proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceName, serviceName, nil)
 
-		// Sticky Cookie Value
-		cv, err := stickycookie.NewFallbackValue(&stickycookie.RawValue{}, &stickycookie.HashValue{})
-		if err != nil {
-			return nil, err
+		if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() {
+			proxy = metricsMiddle.NewServiceMiddleware(ctx, proxy, m.metricsRegistry, serviceName)
 		}
 
-		options = append(options, roundrobin.EnableStickySession(roundrobin.NewStickySessionWithOptions(cookieName, opts).SetCookieValue(cv)))
+		lb.Add(proxyName, proxy, nil)
 
-		logger.Debugf("Sticky session cookie name: %v", cookieName)
-	}
+		// servers are considered UP by default.
+		info.UpdateServerStatus(target.String(), runtime.StatusUp)
 
-	lb, err := roundrobin.New(fwd, options...)
-	if err != nil {
-		return nil, err
+		healthCheckTargets[proxyName] = target
 	}
 
-	lbsu := healthcheck.NewLBStatusUpdater(lb, m.configs[serviceName], service.HealthCheck)
-	if err := m.upsertServers(ctx, lbsu, service.Servers); err != nil {
-		return nil, fmt.Errorf("error configuring load balancer for service %s: %w", serviceName, err)
+	if service.HealthCheck != nil {
+		m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker(
+			ctx,
+			m.metricsRegistry,
+			service.HealthCheck,
+			lb,
+			info,
+			roundTripper,
+			healthCheckTargets,
+		)
 	}
 
-	return lbsu, nil
-}
-
-func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []dynamic.Server) error {
-	logger := log.FromContext(ctx)
-
-	for name, srv := range shuffle(servers, m.rand) {
-		u, err := url.Parse(srv.URL)
-		if err != nil {
-			return fmt.Errorf("error parsing server URL %s: %w", srv.URL, err)
-		}
-
-		logger.WithField(log.ServerName, name).Debugf("Creating server %d %s", name, u)
-
-		if err := lb.UpsertServer(u, roundrobin.Weight(1)); err != nil {
-			return fmt.Errorf("error adding server %s to load balancer: %w", srv.URL, err)
-		}
-
-		// TODO Handle Metrics
-	}
-	return nil
+	return lb, nil
 }
 
-func convertSameSite(sameSite string) http.SameSite {
-	switch sameSite {
-	case "none":
-		return http.SameSiteNoneMode
-	case "lax":
-		return http.SameSiteLaxMode
-	case "strict":
-		return http.SameSiteStrictMode
-	default:
-		return 0
+// LaunchHealthCheck launches the health checks.
+func (m *Manager) LaunchHealthCheck(ctx context.Context) {
+	for serviceName, hc := range m.healthCheckers {
+		go hc.Launch(log.With(ctx, log.Str(log.ServiceName, serviceName)))
 	}
 }
diff --git a/pkg/server/service/service_test.go b/pkg/server/service/service_test.go
index b761b410e9..8fa27c429b 100644
--- a/pkg/server/service/service_test.go
+++ b/pkg/server/service/service_test.go
@@ -24,63 +24,6 @@ func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
 	panic("implement me")
 }
 
-func TestGetLoadBalancer(t *testing.T) {
-	sm := Manager{}
-
-	testCases := []struct {
-		desc        string
-		serviceName string
-		service     *dynamic.ServersLoadBalancer
-		fwd         http.Handler
-		expectError bool
-	}{
-		{
-			desc:        "Fails when provided an invalid URL",
-			serviceName: "test",
-			service: &dynamic.ServersLoadBalancer{
-				Servers: []dynamic.Server{
-					{
-						URL: ":",
-					},
-				},
-			},
-			fwd:         &MockForwarder{},
-			expectError: true,
-		},
-		{
-			desc:        "Succeeds when there are no servers",
-			serviceName: "test",
-			service:     &dynamic.ServersLoadBalancer{},
-			fwd:         &MockForwarder{},
-			expectError: false,
-		},
-		{
-			desc:        "Succeeds when sticky.cookie is set",
-			serviceName: "test",
-			service: &dynamic.ServersLoadBalancer{
-				Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}},
-			},
-			fwd:         &MockForwarder{},
-			expectError: false,
-		},
-	}
-
-	for _, test := range testCases {
-		t.Run(test.desc, func(t *testing.T) {
-			t.Parallel()
-
-			handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd)
-			if test.expectError {
-				require.Error(t, err)
-				assert.Nil(t, handler)
-			} else {
-				require.NoError(t, err)
-				assert.NotNil(t, handler)
-			}
-		})
-	}
-}
-
 func TestGetLoadBalancerServiceHandler(t *testing.T) {
 	sm := NewManager(nil, nil, nil, &RoundTripperManager{
 		roundTrippers: map[string]http.RoundTripper{
@@ -336,7 +279,8 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
 
 	for _, test := range testCases {
 		t.Run(test.desc, func(t *testing.T) {
-			handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service)
+			serviceInfo := runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: test.service}}
+			handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, &serviceInfo)
 			assert.NoError(t, err)
 			assert.NotNil(t, handler)
 
@@ -414,7 +358,8 @@ func Test1xxResponses(t *testing.T) {
 			},
 		},
 	}
-	handler, err := sm.getLoadBalancerServiceHandler(context.Background(), "foobar", config)
+	serviceInfo := runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: config}}
+	handler, err := sm.getLoadBalancerServiceHandler(context.Background(), "foobar", &serviceInfo)
 	assert.NoError(t, err)
 
 	frontend := httptest.NewServer(handler)