diff --git a/.gitignore b/.gitignore index 98b2c19f..846cab3a 100644 --- a/.gitignore +++ b/.gitignore @@ -51,8 +51,9 @@ __pycache__/ # C extensions *.so -# Test binary, build with `go test -c` +# Test binaries *.test +sso-devproxy # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/Makefile b/Makefile index 182f11b5..f93658f3 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ version := "v1.0.0" commit := $(shell git rev-parse --short HEAD) -build: dist/sso-auth dist/sso-proxy +build: dist/sso-auth dist/sso-proxy dist/sso-devproxy dist/sso-auth: mkdir -p dist @@ -12,6 +12,10 @@ dist/sso-proxy: mkdir -p dist go build -o dist/sso-proxy ./cmd/sso-proxy +dist/sso-devproxy: + mkdir -p dist + go build -o dist/sso-devproxy ./cmd/sso-devproxy + test: ./scripts/test diff --git a/cmd/sso-devproxy/main.go b/cmd/sso-devproxy/main.go new file mode 100644 index 00000000..b739a3db --- /dev/null +++ b/cmd/sso-devproxy/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "fmt" + "net/http" + "os" + + "github.com/buzzfeed/sso/internal/devproxy" + log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/kelseyhightower/envconfig" +) + +func init() { + log.SetServiceName("sso-dev-proxy") +} + +func main() { + logger := log.NewLogEntry() + + opts := devproxy.NewOptions() + + err := envconfig.Process("", opts) + if err != nil { + logger.Error(err, "error loading in env vars") + os.Exit(1) + } + + err = opts.Validate() + if err != nil { + logger.Error(err, "error validing options") + os.Exit(1) + } + + proxy, err := devproxy.NewDevProxy(opts) + if err != nil { + logger.Error(err, "error creating devproxy") + os.Exit(1) + } + + s := &http.Server{ + Addr: fmt.Sprintf(":%d", opts.Port), + ReadTimeout: opts.TCPReadTimeout, + WriteTimeout: opts.TCPWriteTimeout, + Handler: devproxy.NewLoggingHandler(os.Stdout, proxy.Handler(), opts.RequestLogging), + } + logger.Fatal(s.ListenAndServe()) +} diff --git a/internal/devproxy/dev_config.go b/internal/devproxy/dev_config.go new file mode 100644 index 00000000..7400c322 --- /dev/null +++ b/internal/devproxy/dev_config.go @@ -0,0 +1,386 @@ +package devproxy + +import ( + "fmt" + "net/url" + "regexp" + "strings" + "time" + + "github.com/18F/hmacauth" + "github.com/imdario/mergo" + "gopkg.in/yaml.v2" +) + +const ( + simple = "simple" + rewrite = "rewrite" +) + +var ( + space = regexp.MustCompile(`\s+`) +) + +// ServiceConfig represents the configuration for a given service +type ServiceConfig struct { + Service string `yaml:"service"` + ClusterConfigs map[string]*UpstreamConfig `yaml:",inline"` +} + +// SimpleRoute contains a FromURL and ToURL used to construct simple routes in the reverse proxy. +type SimpleRoute struct { + FromURL *url.URL + ToURL *url.URL +} + +// RewriteRoute contains a FromRegex and ToTemplate used to construct rewrite routes in the reverse proxy. +type RewriteRoute struct { + FromRegex *regexp.Regexp + ToTemplate *url.URL +} + +// UpstreamConfig represents the configuration for a given cluster in a given service +type UpstreamConfig struct { + Service string + + RouteConfig RouteConfig `yaml:",inline"` + + ExtraRoutes []*RouteConfig `yaml:"extra_routes"` + + // Generated at Parse Time + Route interface{} // note: :/ + HMACAuth hmacauth.HmacAuth + Timeout time.Duration + FlushInterval time.Duration + HeaderOverrides map[string]string + TLSSkipVerify bool + SkipRequestSigning bool + User string + Groups string + Email string +} + +// RouteConfig maps to the yaml config fields, +// * "from" - the domain that will be used to access the service +// * "to" - the cname of the proxied service (this tells sso proxy where to proxy requests that come in on the from field) +type RouteConfig struct { + From string `yaml:"from"` + To string `yaml:"to"` + Type string `yaml:"type"` + Options *OptionsConfig `yaml:"options"` + UserInfo *UserInfo `yaml:"user_info"` +} + +//UserInfo is going to be injected into the header +type UserInfo struct { + User string `yaml:"user"` + Groups string `yaml:"groups"` + Email string `yaml:"email"` +} + +// OptionsConfig maps to the yaml config fields: +// * header_overrides - overrides any heads set either by sso proxy itself or upstream applications. +// This can be useful for modifying browser security headers. +// * skip_auth_regex - skips authentication for paths matching these regular expressions. +// * allowed_groups - optional list of authorized google groups that can access the service. +// * timeout - duration before timing out request. +// * flush_interval - interval at which the proxy should flush data to the browser +type OptionsConfig struct { + HeaderOverrides map[string]string `yaml:"header_overrides"` + Timeout time.Duration `yaml:"timeout"` + FlushInterval time.Duration `yaml:"flush_interval"` + SkipRequestSigning bool `yaml:"skip_request_signing"` +} + +// ErrParsingConfig is an error specific to config parsing. +type ErrParsingConfig struct { + Message string + Err error +} + +// Error() implements the error interface, returning a string representation of the error. +func (e *ErrParsingConfig) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s error=%s", e.Message, e.Err) + } + return e.Message +} + +func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[string]string) ([]*UpstreamConfig, error) { + // We fill in all templated values and resolve overrides + rawTemplated := resolveTemplates(raw, configVars) + + serviceConfigs, err := parseServiceConfigs(rawTemplated) + if err != nil { + return nil, err + } + + // we don't set this to the len(serviceConfig) since not all service configs + // are configured for all clusters, leaving nil tail pointers in the slice. + configs := make([]*UpstreamConfig, 0) + // resolve overrides + for _, service := range serviceConfigs { + proxy, err := resolveUpstreamConfig(service, cluster) + if err != nil { + return nil, err + } + + // if we don't resolve a upstream config, this cluster is not configured for this upstream + // so the proxy struct will be nil and we skip adding it to our running config + if proxy != nil { + configs = append(configs, proxy) + } + } + + extraRoutes := make([]*UpstreamConfig, 0) + for _, proxy := range configs { + if len(proxy.ExtraRoutes) == 0 { + continue + } + + for _, extra := range proxy.ExtraRoutes { + resolvedProxy, err := resolveExtraRoute(extra, proxy) + if err != nil { + return nil, err + } + extraRoutes = append(extraRoutes, resolvedProxy) + } + // for completeness, we set this to nil now that we've processed extra routes + proxy.ExtraRoutes = nil + } + + configs = append(configs, extraRoutes...) + + // We verify the config has necessary values + for _, proxy := range configs { + err := validateUpstreamConfig(proxy) + if err != nil { + return nil, err + } + } + + // We compose the URLs for all our finalized domains + for _, proxy := range configs { + switch proxy.RouteConfig.Type { + case simple, "": + route, err := simpleRoute(scheme, proxy.RouteConfig) + if err != nil { + return nil, err + } + proxy.Route = route + case rewrite: + route, err := rewriteRoute(scheme, proxy.RouteConfig) + if err != nil { + return nil, err + } + proxy.Route = route + default: + return nil, &ErrParsingConfig{ + Message: fmt.Sprintf("unknown routing config type %q", proxy.RouteConfig.Type), + Err: nil, + } + } + } + + // We validate OptionsConfig + for _, proxy := range configs { + err := parseOptionsConfig(proxy) + if err != nil { + return nil, err + } + err = parseUserInfoConfig(proxy) + if err != nil { + return nil, err + } + } + + for _, proxy := range configs { + key := fmt.Sprintf("%s_signing_key", proxy.Service) + signingKey, ok := configVars[key] + if !ok { + continue + } + auth, err := generateHmacAuth(signingKey) + if err != nil { + return nil, &ErrParsingConfig{ + Message: fmt.Sprintf("unable to generate hmac auth for %s", proxy.Service), + Err: err, + } + } + proxy.HMACAuth = auth + + } + + return configs, nil +} + +func rewriteRoute(scheme string, routeConfig RouteConfig) (*RewriteRoute, error) { + compiled, err := regexp.Compile(routeConfig.From) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to compile rewrite from regex", + Err: err, + } + } + + toURL := &url.URL{ + Scheme: scheme, + Opaque: routeConfig.To, // we use opaque since the template value may not be a parsable URL + } + + return &RewriteRoute{ + FromRegex: compiled, + ToTemplate: toURL, + }, nil +} + +func simpleRoute(scheme string, routeConfig RouteConfig) (*SimpleRoute, error) { + // url parse domain + fromURL, err := urlParse(scheme, routeConfig.From) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to url parse `from` parameter", + Err: err, + } + } + + // url parse to url + toURL, err := urlParse(scheme, routeConfig.To) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "unable to url parse `to` parameter", + Err: err, + } + } + + return &SimpleRoute{ + FromURL: fromURL, + ToURL: toURL, + }, nil +} + +func urlParse(scheme, uri string) (*url.URL, error) { + // NOTE: This is done intentionally to add a scheme so it is valid to parse. + // + // From https://golang.org/pkg/net/url/#Parse + // > Trying to parse a hostname and path without a scheme is invalid + // > but may not necessarily return an error, due to parsing ambiguities. + if !strings.Contains(uri, "://") { + uri = fmt.Sprintf("%s://%s", scheme, uri) + } + return url.Parse(uri) +} + +func parseServiceConfigs(data []byte) ([]*ServiceConfig, error) { + serviceConfigs := make([]*ServiceConfig, 0) + err := yaml.Unmarshal(data, &serviceConfigs) + if err != nil { + return nil, &ErrParsingConfig{ + Message: "failed to parse yaml", + Err: err, + } + } + + return serviceConfigs, err +} + +func resolveExtraRoute(routeConfig *RouteConfig, src *UpstreamConfig) (*UpstreamConfig, error) { + dst := &UpstreamConfig{RouteConfig: *routeConfig} + + err := mergo.Merge(dst, *src) + if err != nil { + return nil, err + } + + dst.ExtraRoutes = nil + + return dst, nil +} + +func resolveUpstreamConfig(service *ServiceConfig, override string) (*UpstreamConfig, error) { + dst, dstOk := service.ClusterConfigs["default"] + src, srcOk := service.ClusterConfigs[override] + + if !(dstOk || srcOk) { + // no default or cluster is configured, which we allow + return nil, nil + } + + if dst == nil { + dst = &UpstreamConfig{} + } + + if src == nil { + src = &UpstreamConfig{} + } + + err := mergo.Merge(dst, *src, mergo.WithOverride) + if err != nil { + return nil, err + } + + dst.Service = cleanWhiteSpace(service.Service) + return dst, nil +} + +func validateUpstreamConfig(proxy *UpstreamConfig) error { + if proxy.Service == "" { + return &ErrParsingConfig{ + Message: "missing `service` parameter", + } + } + + if proxy.RouteConfig.From == "" { + return &ErrParsingConfig{ + Message: "missing `from` parameter", + } + } + + if proxy.RouteConfig.To == "" { + return &ErrParsingConfig{ + Message: "missing `to` parameter", + } + } + + return nil +} + +func resolveTemplates(raw []byte, templateVars map[string]string) []byte { + rawString := string(raw) + for k, v := range templateVars { + templated := fmt.Sprintf("{{%s}}", k) + rawString = strings.Replace(rawString, templated, v, -1) + } + return []byte(rawString) +} + +func parseOptionsConfig(proxy *UpstreamConfig) error { + if proxy.RouteConfig.Options == nil { + return nil + } + + proxy.Timeout = proxy.RouteConfig.Options.Timeout + proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval + proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides + proxy.SkipRequestSigning = proxy.RouteConfig.Options.SkipRequestSigning + proxy.RouteConfig.Options = nil + + return nil +} + +func parseUserInfoConfig(proxy *UpstreamConfig) error { + if proxy.RouteConfig.UserInfo == nil { + return nil + } + + proxy.User = proxy.RouteConfig.UserInfo.User + proxy.Groups = proxy.RouteConfig.UserInfo.Groups + proxy.Email = proxy.RouteConfig.UserInfo.Email + + return nil +} + +func cleanWhiteSpace(s string) string { + // This trims all white space from a service name and collapses all remaining space to `_` + return space.ReplaceAllString(strings.TrimSpace(s), "_") // +} diff --git a/internal/devproxy/devproxy.go b/internal/devproxy/devproxy.go new file mode 100644 index 00000000..43ac0d92 --- /dev/null +++ b/internal/devproxy/devproxy.go @@ -0,0 +1,474 @@ +package devproxy + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "html/template" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "strings" + "time" + + "github.com/18F/hmacauth" + log "github.com/buzzfeed/sso/internal/pkg/logging" +) + +// HMACSignatureHeader is the header name where the signed request header is stored. +const HMACSignatureHeader = "Gap-Signature" + +// SignatureHeaders are the headers that are valid in the request. +var SignatureHeaders = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Groups", + "Cookie", +} + +const statusInvalidHost = 421 + +// DevProxy stores all the information associated with proxying the request. +type DevProxy struct { + skipAuthPreflight bool + templates *template.Template + mux map[string]*route + regexRoutes []*route + requestSigner *RequestSigner + publicCertsJSON []byte + user string + groups string + email string +} + +type route struct { + upstreamConfig *UpstreamConfig + handler http.Handler + tags []string + + // only used for ones that have regex + regex *regexp.Regexp +} + +// StateParameter holds the redirect id along with the session id. +type StateParameter struct { + SessionID string `json:"session_id"` + RedirectURI string `json:"redirect_uri"` +} + +// UpstreamProxy stores information necessary for proxying the request back to the upstream. +type UpstreamProxy struct { + name string + handler http.Handler + auth hmacauth.HmacAuth + requestSigner *RequestSigner +} + +// upstreamTransport is used to ensure that upstreams cannot override the +// security headers applied by dev_proxy +type upstreamTransport struct { + transport *http.Transport +} + +// RoundTrip round trips the request and deletes security headers before returning the response. +func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + logger := log.NewLogEntry() + logger.Error(err, "error in upstreamTransport RoundTrip") + return nil, err + } + for key := range securityHeaders { + resp.Header.Del(key) + } + return resp, err +} + +func newUpstreamTransport(insecureSkipVerify bool) *upstreamTransport { + return &upstreamTransport{ + transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: insecureSkipVerify}, + ExpectContinueTimeout: 1 * time.Second, + }, + } +} + +// ServeHTTP calls the upstream's ServeHTTP function. +func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if u.auth != nil { + u.auth.SignRequest(r) + } + if u.requestSigner != nil { + u.requestSigner.Sign(r) + } + + start := time.Now() + u.handler.ServeHTTP(w, r) + duration := time.Now().Sub(start) + + fmt.Printf("service_name:%s, duration:%s", u.name, duration) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewReverseProxy creates a reverse proxy to a specified url. +// It adds an X-Forwarded-Host header that is the request's host. +func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy { + targetQuery := to.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = to.Scheme + req.URL.Host = to.Host + req.URL.Path = singleJoiningSlash(to.Path, req.URL.Path) + + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + } + proxy := &httputil.ReverseProxy{Director: director} + proxy.Transport = newUpstreamTransport(config.TLSSkipVerify) + dir := proxy.Director + + proxy.Director = func(req *http.Request) { + req.Header.Add("X-Forwarded-Host", req.Host) + dir(req) + req.Host = to.Host + } + + return proxy +} + +// NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream +// urls on the fly based on a from regex and a templated to field. +// It adds an X-Forwarded-Host header to the the upstream's request. +func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httputil.ReverseProxy { + proxy := &httputil.ReverseProxy{} + proxy.Transport = newUpstreamTransport(config.TLSSkipVerify) + proxy.Director = func(req *http.Request) { + // we do this to rewrite requests + rewritten := route.FromRegex.ReplaceAllString(req.Host, route.ToTemplate.Opaque) + + // we use to favor scheme's used in the regex, else we use the default passed in via the template + target, err := urlParse(route.ToTemplate.Scheme, rewritten) + if err != nil { + logger := log.NewLogEntry() + // we aren't in an error handling context so we have to fake it(thanks stdlib!) + logger.WithRequestHost(req.Host).WithRewriteRoute(route).Error( + err, "unable to parse and replace rewrite url") + req.URL = nil // this will raise an error in http.RoundTripper + return + } + director := httputil.NewSingleHostReverseProxy(target).Director + + req.Header.Add("X-Forwarded-Host", req.Host) + director(req) + req.Host = target.Host + } + return proxy +} + +// NewReverseProxyHandler creates a new http.Handler given a httputil.ReverseProxy +func NewReverseProxyHandler(reverseProxy *httputil.ReverseProxy, opts *Options, config *UpstreamConfig, signer *RequestSigner) (http.Handler, []string) { + upstreamProxy := &UpstreamProxy{ + name: config.Service, + handler: reverseProxy, + auth: config.HMACAuth, + requestSigner: signer, + } + + if config.SkipRequestSigning { + upstreamProxy.requestSigner = nil + } + + if config.FlushInterval != 0 { + return NewStreamingHandler(upstreamProxy, opts, config), []string{"handler:streaming"} + } + return NewTimeoutHandler(upstreamProxy, opts, config), []string{"handler:timeout"} +} + +// NewTimeoutHandler creates a new handler with a configure timeout. +func NewTimeoutHandler(handler http.Handler, opts *Options, config *UpstreamConfig) http.Handler { + timeout := opts.DefaultUpstreamTimeout + if config.Timeout != 0 { + timeout = config.Timeout + } + timeoutMsg := fmt.Sprintf( + "%s failed to respond within the %s timeout period", config.Service, timeout) + return http.TimeoutHandler(handler, timeout, timeoutMsg) +} + +// NewStreamingHandler creates a new handler capable of proxying a stream +func NewStreamingHandler(handler http.Handler, opts *Options, config *UpstreamConfig) http.Handler { + upstreamProxy := handler.(*UpstreamProxy) + reverseProxy := upstreamProxy.handler.(*httputil.ReverseProxy) + reverseProxy.FlushInterval = config.FlushInterval + return upstreamProxy +} + +func generateHmacAuth(signatureKey string) (hmacauth.HmacAuth, error) { + components := strings.Split(signatureKey, ":") + if len(components) != 2 { + return nil, fmt.Errorf("invalid signature hash:key spec") + } + + algorithm, secret := components[0], components[1] + hash, err := hmacauth.DigestNameToCryptoHash(algorithm) + if err != nil { + return nil, fmt.Errorf("unsupported signature hash algorithm: %s", algorithm) + } + auth := hmacauth.NewHmacAuth(hash, []byte(secret), HMACSignatureHeader, SignatureHeaders) + return auth, nil +} + +// NewDevProxy creates a new DevProxy struct. +func NewDevProxy(opts *Options, optFuncs ...func(*DevProxy) error) (*DevProxy, error) { + logger := log.NewLogEntry() + logger.Info("NewDevProxy...") + + // Configure the RequestSigner (used to sign requests with `Sso-Signature` header). + // Also build the `certs` static JSON-string which will be served from a public endpoint. + // The key published at this endpoint allows upstreams to decrypt the `Sso-Signature` + // header, and validate the integrity and authenticity of a request. + + certs := make(map[string]string) + var requestSigner *RequestSigner + var err error + if len(opts.RequestSigningKey) > 0 { + requestSigner, err = NewRequestSigner(opts.RequestSigningKey) + + if err != nil { + return nil, fmt.Errorf("could not build RequestSigner: %s", err) + } + id, key := requestSigner.PublicKey() + certs[id] = key + + } else { + logger.Warn("Running DevProxy without signing key. Requests will not be signed.") + } + + certsAsStr, err := json.MarshalIndent(certs, "", " ") + if err != nil { + return nil, fmt.Errorf("could not marshal public certs as JSON: %s", err) + } + + p := &DevProxy{ + // these fields make up the routing mechanism + mux: make(map[string]*route), + regexRoutes: make([]*route, 0), + templates: getTemplates(), + requestSigner: requestSigner, + publicCertsJSON: certsAsStr, + } + + for _, optFunc := range optFuncs { + err := optFunc(p) + if err != nil { + return nil, err + } + } + for _, upstreamConfig := range opts.upstreamConfigs { + p.user = upstreamConfig.User + p.email = upstreamConfig.Email + p.groups = upstreamConfig.Groups + switch route := upstreamConfig.Route.(type) { + case *SimpleRoute: + reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig, requestSigner) + p.Handle(route.FromURL.Host, handler, tags, upstreamConfig) + case *RewriteRoute: + reverseProxy := NewRewriteReverseProxy(route, upstreamConfig) + handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig, requestSigner) + p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig) + default: + return nil, fmt.Errorf("unknown route type") + } + } + + return p, nil +} + +// Handler returns a http handler for an DevProxy +func (p *DevProxy) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/favicon.ico", p.Favicon) + mux.HandleFunc("/robots.txt", p.RobotsTxt) + mux.HandleFunc("/oauth2/v1/certs", p.Certs) + mux.HandleFunc("/", p.Proxy) + + // Global middleware, which will be applied to each request in reverse + // order as applied here (i.e., we want to validate the host _first_ when + // processing a request) + var handler http.Handler = mux + handler = p.setResponseHeaderOverrides(handler) + handler = setSecurityHeaders(handler) + handler = p.validateHost(handler) + + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // Skip host validation for /ping requests because they hit the LB directly. + if req.URL.Path == "/ping" { + p.PingPage(rw, req) + return + } + handler.ServeHTTP(rw, req) + }) +} + +// RobotsTxt sets the User-Agent header in the response to be "Disallow" +func (p *DevProxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "User-agent: *\nDisallow: /") +} + +// Favicon will proxy the request as usual +func (p *DevProxy) Favicon(rw http.ResponseWriter, req *http.Request) { + err := p.setProxyHeaders(rw, req) + if err != nil { + rw.WriteHeader(http.StatusNotFound) + return + } + rw.WriteHeader(http.StatusOK) + p.Proxy(rw, req) +} + +// PingPage send back a 200 OK response. +func (p *DevProxy) PingPage(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "OK") +} + +// ErrorPage renders an error page with a given status code, title, and message. +func (p *DevProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) { + if p.isXMLHTTPRequest(req) { + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(code) + err := json.NewEncoder(rw).Encode(struct { + Error string `json:"error"` + }{ + Error: message, + }) + if err != nil { + io.WriteString(rw, err.Error()) + } + } else { + logger := log.NewLogEntry() + logger.WithHTTPStatus(code).WithPageTitle(title).WithPageMessage(message).Info( + "error page") + rw.WriteHeader(code) + t := struct { + Code int + Title string + Message string + }{ + Code: code, + Title: title, + Message: message, + } + p.templates.ExecuteTemplate(rw, "error.html", t) + } +} + +func (p *DevProxy) isXMLHTTPRequest(req *http.Request) bool { + return req.Header.Get("X-Requested-With") == "XMLHttpRequest" +} + +// Proxy forwards the request. +func (p *DevProxy) Proxy(rw http.ResponseWriter, req *http.Request) { + + logger := log.NewLogEntry() + p.setProxyHeaders(rw, req) + + logger.Info("Proxy...") + // We now proxy their request to the provided upstream. + route, ok := p.router(req) + if !ok { + p.UnknownHost(rw, req) + return + } + + route.handler.ServeHTTP(rw, req) +} + +// UnknownHost returns an http error for unknown or invalid hosts +func (p *DevProxy) UnknownHost(rw http.ResponseWriter, req *http.Request) { + logger := log.NewLogEntry() + + logger.WithRequestHost(req.Host).Error("unknown host") + http.Error(rw, "", statusInvalidHost) +} + +// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig +func (p *DevProxy) Handle(host string, handler http.Handler, tags []string, upstreamConfig *UpstreamConfig) { + + tags = append(tags, "route:simple") + p.mux[host] = &route{handler: handler, upstreamConfig: upstreamConfig, tags: tags} +} + +// HandleRegex constructs a route from the given regexp and matches it to the provided http.Handler and UpstreamConfig +func (p *DevProxy) HandleRegex(regex *regexp.Regexp, handler http.Handler, tags []string, upstreamConfig *UpstreamConfig) { + tags = append(tags, "route:rewrite") + p.regexRoutes = append(p.regexRoutes, &route{regex: regex, handler: handler, upstreamConfig: upstreamConfig, tags: tags}) +} + +func (p *DevProxy) setProxyHeaders(rw http.ResponseWriter, req *http.Request) (err error) { + req.Header.Set("X-Forwarded-User", p.user) + req.Header.Set("X-Forwarded-Email", p.email) + req.Header.Set("X-Forwarded-Groups", p.groups) + // req.Header.set("X-Forwarded-Access-Token", "") + return nil +} + +// router attempts to find a route for a request. If a route is successfully matched, +// it returns the route information and a bool value of `true`. If a route can not be matched, +//a nil value for the route and false bool value is returned. +func (p *DevProxy) router(req *http.Request) (*route, bool) { + route, ok := p.mux[req.Host] + if ok { + return route, true + } + + for _, route := range p.regexRoutes { + if route.regex.MatchString(req.Host) { + return route, true + } + } + + return nil, false +} + +// Certs publishes the public key necessary for upstream services to validate the digital signature +// used to sign each request. +func (p *DevProxy) Certs(rw http.ResponseWriter, _ *http.Request) { + rw.Write(p.publicCertsJSON) +} diff --git a/internal/devproxy/devproxy_test.go b/internal/devproxy/devproxy_test.go new file mode 100644 index 00000000..842a235f --- /dev/null +++ b/internal/devproxy/devproxy_test.go @@ -0,0 +1,967 @@ +package devproxy + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +func init() { + log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) + +} + +func testValidatorFunc(valid bool) func(*DevProxy) error { + return func(p *DevProxy) error { + return nil + } +} + +func TestNewReverseProxy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + hostname, _, _ := net.SplitHostPort(r.Host) + w.Write([]byte(hostname)) + })) + defer backend.Close() + + backendURL, _ := url.Parse(backend.URL) + backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) + backendHost := net.JoinHostPort(backendHostname, backendPort) + proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") + + proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{TLSSkipVerify: false}) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + res, _ := http.DefaultClient.Do(getReq) + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendHostname; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestNewRewriteReverseProxy(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(200) + rw.Write([]byte(req.Host)) + })) + defer upstream.Close() + + parsedUpstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("expected to parse upstream URL err:%q", err) + } + + route := &RewriteRoute{ + FromRegex: regexp.MustCompile("(.*)"), + ToTemplate: &url.URL{ + Scheme: parsedUpstreamURL.Scheme, + Opaque: parsedUpstreamURL.Host, + }, + } + + rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{TLSSkipVerify: false}) + + frontend := httptest.NewServer(rewriteProxy) + defer frontend.Close() + + resp, err := http.Get(frontend.URL) + if err != nil { + t.Fatalf("expected to make successful request err:%q", err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expected to read body err:%q", err) + } + + if string(body) != parsedUpstreamURL.Host { + t.Logf("got %v", string(body)) + t.Logf("want %v", parsedUpstreamURL.Host) + t.Fatalf("got unexpected response from upstream") + } +} + +func TestNewReverseProxyHostname(t *testing.T) { + type respStruct struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` + // XForwardedEmail string `json:"x-forwarded-email"` + // XForwardedUser string `json:"x-forwarded-user"` + // XForwardedGroups string `json:"x-forwarded-groups"` + // XForwardedAccessToken string `json:"x-forwarded-access-token"` + } + + to := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &respStruct{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + // XForwardedEmail: r.Header.Get(`json:"x-forwarded-email"`), + // XForwardedUser: r.Header.Get(`json:"x-forwarded-user"`), + // XForwardedGroups: r.Header.Get(`json:"x-forwarded-groups"`), + // XForwardedAccessToken: r.Header.Get(`json:"x-forwarded-access-token"`), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer to.Close() + + toURL, err := url.Parse(to.URL) + if err != nil { + t.Fatalf("expected to parse to url: %s", err) + } + + reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: false}) + from := httptest.NewServer(reverseProxy) + defer from.Close() + + fromURL, err := url.Parse(from.URL) + if err != nil { + t.Fatalf("expected to parse from url: %s", err) + } + + want := &respStruct{ + Host: toURL.Host, + XForwardedHost: fromURL.Host, + } + + res, err := http.Get(from.URL) + if err != nil { + t.Fatalf("expected to be able to make req: %s", err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } + + got := &respStruct{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } + + if !reflect.DeepEqual(want, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", want.Host) + + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost) + + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + +} + +func TestRoundTrip(t *testing.T) { + testCases := []struct { + name string + url string + expectedError bool + }{ + { + name: "no error", + url: "https://www.example.com/", + }, + { + name: "with error", + url: "/", + expectedError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tc.url, nil) + ut := upstreamTransport{} + resp, err := ut.RoundTrip(req) + if err == nil && tc.expectedError { + t.Errorf("expected error but error was nil") + } + if err != nil && !tc.expectedError { + t.Errorf("unexpected error %s", err.Error()) + } + if err != nil { + return + } + for key := range securityHeaders { + if resp.Header.Get(key) != "" { + t.Errorf("security header %s expected to be deleted but was %s", key, resp.Header.Get(key)) + } + } + }) + } +} + +func generateTestUpstreamConfigs(to string) []*UpstreamConfig { + if !strings.Contains(to, "://") { + to = fmt.Sprintf("%s://%s", "http", to) + } + parsed, err := url.Parse(to) + if err != nil { + panic(err) + } + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.sso.dev + to: %s +`, parsed)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + return upstreamConfigs +} + +func TestRobotsTxt(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("httpheader.net/") + opts.Validate() + + proxy, err := NewDevProxy(opts) + if err != nil { + t.Errorf("unexpected error %s", err) + } + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/robots.txt", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, 200, rw.Code) + testutil.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) +} + +func TestCerts(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("foo-internal.sso.dev") + + requestSigningKey, err := ioutil.ReadFile("testdata/private_key.pem") + testutil.Assert(t, err == nil, "could not read private key from testdata: %s", err) + opts.RequestSigningKey = string(requestSigningKey) + opts.Validate() + + expectedPublicKey, err := ioutil.ReadFile("testdata/public_key.pem") + testutil.Assert(t, err == nil, "could not read public key from testdata: %s", err) + + var keyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(expectedPublicKey) + keyHash = hasher.Sum(keyHash) + + proxy, err := NewDevProxy(opts) + if err != nil { + t.Errorf("unexpected error %s", err) + return + } + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/oauth2/v1/certs", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, 200, rw.Code) + + var certs map[string]string + if err := json.Unmarshal([]byte(rw.Body.String()), &certs); err != nil { + t.Errorf("failed to unmarshal certs from json response: %s", err) + return + } + testutil.Equal(t, string(expectedPublicKey), certs[hex.EncodeToString(keyHash)]) +} + +func TestFavicon(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs("httpheader.net/") + opts.Validate() + + proxy, _ := NewDevProxy(opts) + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "https://foo.sso.dev/favicon.ico", nil) + proxy.Handler().ServeHTTP(rw, req) + testutil.Equal(t, http.StatusOK, rw.Code) +} + +type SignatureTest struct { + opts *Options + upstream *httptest.Server + upstreamHost string + header http.Header + rw *httptest.ResponseRecorder +} + +func generateSignatureTestUpstreamConfigs(key, to string) []*UpstreamConfig { + + if !strings.Contains(to, "://") { + to = fmt.Sprintf("%s://%s", "http", to) + } + parsed, err := url.Parse(to) + if err != nil { + panic(err) + } + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + "foo_signing_key": key, + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.{{cluster}}.{{root_domain}} + to: %s +`, parsed)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + + return upstreamConfigs +} + +func (st *SignatureTest) Close() { + st.upstream.Close() +} + +func TestPing(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("upstream")) + })) + defer upstream.Close() + + opts := NewOptions() + + opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) + opts.Validate() + + proxy, _ := NewDevProxy(opts) + + testCases := []struct { + name string + url string + host string + authenticated bool + expectedCode int + }{ + { + name: "ping never reaches upstream", + url: "http://foo.sso.dev/ping", + authenticated: true, + expectedCode: http.StatusOK, + }, + { + name: "ping skips host check with no host set", + url: "/ping", + expectedCode: http.StatusOK, + }, + { + name: "ping skips host check with unknown host set", + url: "/ping", + host: "example.com", + expectedCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", tc.url, nil) + + proxy.Handler().ServeHTTP(rw, req) + + if tc.expectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) + } + if rw.Body.String() != "OK" { + t.Errorf("expected body = %q, got %q", "OK", rw.Body.String()) + } + }) + } +} + +func TestSecurityHeaders(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/add-header": + w.Header().Set("X-Test-Header", "true") + case "/override-security-header": + w.Header().Set("X-Frame-Options", "OVERRIDE") + } + w.WriteHeader(200) + w.Write([]byte(r.URL.RequestURI())) + })) + defer upstream.Close() + + opts := NewOptions() + opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) + opts.Validate() + + proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + + testCases := []struct { + name string + path string + expectedCode int + expectedHeaders map[string]string + }{ + { + name: "security headers are added to authenticated requests", + path: "/", + expectedCode: http.StatusOK, + expectedHeaders: securityHeaders, + }, + // { + // name: "security headers are added to unauthenticated requests", + // path: "/", + // expectedCode: http.StatusFound, + // expectedHeaders: securityHeaders, + // }, + { + name: "additional headers set by upstream are proxied", + path: "/add-header", + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Test-Header": "true", + }, + }, + { + name: "security headers may NOT be overridden by upstream", + path: "/override-security-header", + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Frame-Options": "SAMEORIGIN", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", fmt.Sprintf("http://foo.sso.dev%s", tc.path), nil) + + proxy.Handler().ServeHTTP(rw, req) + + if tc.expectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) + out, _ := json.Marshal(tc) + fmt.Println(string(out)) + } + if tc.expectedCode == http.StatusOK { + if rw.Body.String() != tc.path { + t.Errorf("expected body = %q, got %q", tc.path, rw.Body.String()) + } + } + for key, val := range tc.expectedHeaders { + vals, found := rw.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + }) + } +} + +func makeUpstreamConfigWithHeaderOverrides(overrides map[string]string) []*UpstreamConfig { + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: foo.sso.dev + to: httpheader.net/ + +- service: bar + default: + from: bar.sso.dev + to: bar-internal.sso.dev +`)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + upstreamConfigs[0].HeaderOverrides = overrides // we override foo and not bar + return upstreamConfigs +} + +func TestHeaderOverrides(t *testing.T) { + testCases := []struct { + name string + overrides map[string]string + expectedCode int + expectedHeaders map[string]string + }{ + { + name: "security headers are added to requests", + overrides: nil, + expectedCode: http.StatusOK, + expectedHeaders: securityHeaders, + }, + { + name: "security headers are overridden by config", + overrides: map[string]string{ + "X-Frame-Options": "ALLOW-FROM nsa.gov", + }, + expectedCode: http.StatusOK, + expectedHeaders: map[string]string{ + "X-Frame-Options": "ALLOW-FROM nsa.gov", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = makeUpstreamConfigWithHeaderOverrides(tc.overrides) + opts.Validate() + + proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + + // Check Foo + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://foo.sso.dev/", nil) + proxy.Handler().ServeHTTP(rw, req) + for key, val := range tc.expectedHeaders { + vals, found := rw.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + + // Check Bar + rwBar := httptest.NewRecorder() + reqBar, _ := http.NewRequest("GET", "http://bar.sso.dev/", nil) + proxy.Handler().ServeHTTP(rwBar, reqBar) + for key, val := range securityHeaders { + vals, found := rwBar.HeaderMap[http.CanonicalHeaderKey(key)] + if !found { + t.Errorf("expected header %s not found", key) + } else if len(vals) > 1 { + t.Errorf("got duplicate values for headers %s: %v", key, vals) + } else if vals[0] != val { + t.Errorf("expected header %s=%q, got %s=%q\n", key, val, key, vals[0]) + } + } + }) + } +} + +// func TestHTTPSRedirect(t *testing.T) { +// upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte(r.URL.String())) +// })) +// defer upstream.Close() + +// testCases := []struct { +// name string +// url string +// host string +// requestHeaders map[string]string +// expectedCode int +// expectedLocation string // must match entire Location header +// expectedLocationHost string // just match hostname of Location header +// expectSTS bool // should we get a Strict-Transport-Security header? +// }{ +// { +// name: "no https redirect with http ", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: false, +// }, +// { +// name: "no https redirect with https ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: false, +// }, +// { +// name: "https redirect ", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/", +// expectSTS: true, +// }, +// { +// name: "https redirect", +// url: "http://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/", +// expectSTS: true, +// }, +// { +// name: "no https redirect ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "no https redirect ", +// url: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "request path and query are preserved in redirect", +// url: "http://foo.sso.dev/foo/bar.html?a=1&b=2&c=3", +// expectedCode: http.StatusOK, +// expectedLocation: "https://foo.sso.dev/foo/bar.html?a=1&b=2&c=3", +// expectSTS: true, +// }, +// { +// name: "no https redirect with http and X-Forwarded-Proto=https", +// url: "http://foo.sso.dev/", +// requestHeaders: map[string]string{"X-Forwarded-Proto": "https"}, +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "correct host name with relative URL", +// url: "/", +// host: "foo.sso.dev", +// expectedLocation: "https://foo.sso.dev/", +// expectedCode: http.StatusOK, +// expectSTS: true, +// }, +// { +// name: "host validation is applied before https redirect", +// url: "http://bar.sso.dev/", +// expectedCode: statusInvalidHost, +// expectSTS: false, +// }, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// opts := NewOptions() +// opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) +// opts.Validate() + +// proxy, _ := NewDevProxy(opts, testValidatorFunc(true)) + +// rw := httptest.NewRecorder() +// req, _ := http.NewRequest("GET", tc.url, nil) + +// for key, val := range tc.requestHeaders { +// req.Header.Set(key, val) +// } + +// if tc.host != "" { +// req.Host = tc.host +// } + +// proxy.Handler().ServeHTTP(rw, req) + +// if tc.expectedCode != rw.Code { +// t.Errorf("expected code %d, got %d", tc.expectedCode, rw.Code) +// } + +// location := rw.Header().Get("Location") +// locationURL, err := url.Parse(location) +// if err != nil { +// t.Errorf("error parsing location %q: %s", location, err) +// } +// if tc.expectedLocation != "" && location != tc.expectedLocation { +// t.Errorf("expected Location=%q, got Location=%q", tc.expectedLocation, location) +// } +// if tc.expectedLocationHost != "" && locationURL.Hostname() != tc.expectedLocationHost { +// t.Errorf("expected location host = %q, got %q", tc.expectedLocationHost, locationURL.Hostname()) +// } + +// stsKey := http.CanonicalHeaderKey("Strict-Transport-Security") +// if tc.expectSTS { +// val := rw.Header().Get(stsKey) +// expectedVal := "max-age=31536000" +// if val != expectedVal { +// t.Errorf("expected %s=%q, got %q", stsKey, expectedVal, val) +// } +// } else { +// _, found := rw.HeaderMap[stsKey] +// if found { +// t.Errorf("%s header should not be present, got %q", stsKey, rw.Header().Get(stsKey)) +// } +// } +// }) +// } +// } + +func TestTimeoutHandler(t *testing.T) { + testCases := []struct { + name string + config *UpstreamConfig + defaultTimeout time.Duration + globalTimeout time.Duration + ExpectedStatusCode int + ExpectedBody string + ExpectedErr error + }{ + { + name: "does not timeout", + config: &UpstreamConfig{ + Timeout: time.Duration(100) * time.Millisecond, + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 200, + ExpectedBody: "OK", + }, + { + name: "times out using upstream config timeout", + config: &UpstreamConfig{ + Service: "service-test", + Timeout: time.Duration(10) * time.Millisecond, + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 503, + ExpectedBody: fmt.Sprintf("service-test failed to respond within the 10ms timeout period"), + }, + { + name: "times out using default upstream config timeout", + config: &UpstreamConfig{ + Service: "service-test", + }, + defaultTimeout: time.Duration(10) * time.Millisecond, + globalTimeout: time.Duration(100) * time.Millisecond, + ExpectedStatusCode: 503, + ExpectedBody: fmt.Sprintf("service-test failed to respond within the 10ms timeout period"), + }, + { + name: "times out using global write timeout", + config: &UpstreamConfig{ + Service: "service-test", + }, + defaultTimeout: time.Duration(100) * time.Millisecond, + globalTimeout: time.Duration(10) * time.Millisecond, + ExpectedErr: &url.Error{ + Err: io.EOF, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := NewOptions() + opts.DefaultUpstreamTimeout = tc.defaultTimeout + + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timer := time.NewTimer(time.Duration(50) * time.Millisecond) + <-timer.C + w.Write([]byte("OK")) + }) + timeoutHandler := NewTimeoutHandler(baseHandler, opts, tc.config) + + srv := httptest.NewUnstartedServer(timeoutHandler) + srv.Config.WriteTimeout = tc.globalTimeout + srv.Start() + defer srv.Close() + + res, err := http.Get(srv.URL) + if err != nil { + if tc.ExpectedErr == nil { + t.Fatalf("got unexpected err=%v", err) + } + urlErr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got unexpected err=%v", err) + } + if urlErr.Err != io.EOF { + t.Fatalf("got unexpected err=%v", err) + } + // We got the error we expected, exit + return + } + + if res.StatusCode != tc.ExpectedStatusCode { + t.Errorf(" got=%v", res.StatusCode) + t.Errorf("want=%v", tc.ExpectedStatusCode) + t.Fatalf("got unexpected status code") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("got unexpected err=%q", err) + } + + if string(body) != tc.ExpectedBody { + t.Errorf(" got=%q", body) + t.Errorf("want=%q", tc.ExpectedBody) + t.Fatalf("got unexpected body") + } + }) + } +} + +func generateTestRewriteUpstreamConfigs(fromRegex, toTemplate string) []*UpstreamConfig { + templateVars := map[string]string{ + "root_domain": "dev", + "cluster": "sso", + } + upstreamConfigs, err := loadServiceConfigs([]byte(fmt.Sprintf(` +- service: foo + default: + from: %s + to: %s + type: rewrite +`, fromRegex, toTemplate)), "DevProxy", "http", templateVars) + if err != nil { + panic(err) + } + return upstreamConfigs +} + +func TestRewriteRoutingHandling(t *testing.T) { + type response struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` + XForwardedEmail string `json:"x-forwarded-email"` + XForwardedUser string `json:"x-forwarded-user"` + XForwardedGroups string `json:"x-forwarded-groups"` + XForwardedAccessToken string `json:"x-forwarded-access-token"` + } + + upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &response{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + XForwardedEmail: r.Header.Get(`json:"x-forwarded-email"`), + XForwardedUser: r.Header.Get(`json:"x-forwarded-user"`), + XForwardedGroups: r.Header.Get(`json:"x-forwarded-groups"`), + XForwardedAccessToken: r.Header.Get(`json:"x-forwarded-access-token"`), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer upstream.Close() + + parsedUpstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("expected to parse upstream URL err:%q", err) + } + + upstreamHost, upstreamPort, err := net.SplitHostPort(parsedUpstreamURL.Host) + if err != nil { + t.Fatalf("expected to split host/port err:%q", err) + } + + testCases := []struct { + Name string + TestHost string + TestUser string + TestEmail string + TestGroups string + TestAccessToken string + FromRegex string + ToTemplate string + ExpectedCode int + ExpectedResponse *response + }{ + { + Name: "everything should work in the normal case", + TestHost: "foo.sso.dev", + FromRegex: "(.*)", + ToTemplate: parsedUpstreamURL.Host, + ExpectedCode: http.StatusOK, + ExpectedResponse: &response{ + Host: parsedUpstreamURL.Host, + XForwardedHost: "foo.sso.dev", + }, + }, + { + Name: "it should not match a non-matching regex", + TestHost: "foo.sso.dev", + FromRegex: "bar", + ToTemplate: parsedUpstreamURL.Host, + ExpectedCode: statusInvalidHost, + }, + { + Name: "it should match and replace using regex/template to find port in embedded domain", + TestHost: fmt.Sprintf("somedomain--%s", upstreamPort), + FromRegex: "somedomain--(.*)", // capture port + ToTemplate: fmt.Sprintf("%s:$1", upstreamHost), // add port to dest + ExpectedCode: http.StatusOK, + ExpectedResponse: &response{ + Host: parsedUpstreamURL.Host, + XForwardedHost: fmt.Sprintf("somedomain--%s", upstreamPort), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + opts := NewOptions() + opts.upstreamConfigs = generateTestRewriteUpstreamConfigs(tc.FromRegex, tc.ToTemplate) + opts.Validate() + proxy, err := NewDevProxy(opts, testValidatorFunc(true)) + if err != nil { + t.Fatalf("unexpected err provisioning dev proxy err:%q", err) + } + + req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/", tc.TestHost), strings.NewReader("")) + if err != nil { + t.Fatalf("unexpected err creating request err:%s", err) + } + + rw := httptest.NewRecorder() + + proxy.Handler().ServeHTTP(rw, req) + + if tc.ExpectedCode != rw.Code { + t.Errorf("expected code %d, got %d", tc.ExpectedCode, rw.Code) + } + + if tc.ExpectedResponse == nil { + // we've passed our test, we didn't expect a body, exit early + return + } + + body, err := ioutil.ReadAll(rw.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } + + got := &response{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } + + if !reflect.DeepEqual(tc.ExpectedResponse, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", tc.ExpectedResponse.Host) + + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", tc.ExpectedResponse.XForwardedHost) + + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + }) + } +} diff --git a/internal/devproxy/logging_handler.go b/internal/devproxy/logging_handler.go new file mode 100644 index 00000000..f11825d7 --- /dev/null +++ b/internal/devproxy/logging_handler.go @@ -0,0 +1,112 @@ +// largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go +// to add logging of request duration as last value (and drop referrer) + +package devproxy + +import ( + "io" + "net/http" + "net/url" + "strings" + "time" + + log "github.com/buzzfeed/sso/internal/pkg/logging" +) + +// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status +// code and body size +type responseLogger struct { + w http.ResponseWriter + status int + size int +} + +func (l *responseLogger) Header() http.Header { + return l.w.Header() +} + +func (l *responseLogger) Write(b []byte) (int, error) { + if l.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + l.status = http.StatusOK + } + + size, err := l.w.Write(b) + l.size += size + return size, err +} + +func (l *responseLogger) WriteHeader(s int) { + l.w.WriteHeader(s) + l.status = s +} + +func (l *responseLogger) Status() int { + return l.status +} + +func (l *responseLogger) Size() int { + return l.size +} + +func (l *responseLogger) Flush() { + f := l.w.(http.Flusher) + f.Flush() +} + +// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends +type loggingHandler struct { + writer io.Writer + handler http.Handler + enabled bool +} + +// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer. +func NewLoggingHandler(out io.Writer, h http.Handler, v bool) http.Handler { + return loggingHandler{writer: out, + handler: h, + enabled: v, + } +} + +func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + t := time.Now() + url := *req.URL + logger := &responseLogger{w: w} + h.handler.ServeHTTP(logger, req) + if !h.enabled { + return + } + requestDuration := time.Now().Sub(t) + logRequest(req, url, requestDuration, logger.Status()) +} + +// logRequests creates a log message from the request status, method, url, proxy host and duration of the request +func logRequest(req *http.Request, url url.URL, requestDuration time.Duration, status int) { + // Convert duration to floating point milliseconds + // https://github.com/golang/go/issues/5491#issuecomment-66079585 + durationMS := requestDuration.Seconds() * 1e3 + + logger := log.NewLogEntry() + logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( + url.RequestURI()).WithUserAgent( + req.Header.Get("User-Agent")).WithRemoteAddress( + getRemoteAddr(req)).WithRequestDurationMs( + durationMS).WithAction(GetActionTag(req)).Info() +} + +// getRemoteAddr returns the client IP address from a request. If present, the +// X-Forwarded-For header is assumed to be set by a load balancer, and its +// rightmost entry (the client IP that connected to the LB) is returned. +func getRemoteAddr(req *http.Request) string { + addr := req.RemoteAddr + forwardedHeader := req.Header.Get("X-Forwarded-For") + if forwardedHeader != "" { + forwardedList := strings.Split(forwardedHeader, ",") + forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1]) + if forwardedAddr != "" { + addr = forwardedAddr + } + } + return addr +} diff --git a/internal/devproxy/metrics.go b/internal/devproxy/metrics.go new file mode 100644 index 00000000..228c9b1c --- /dev/null +++ b/internal/devproxy/metrics.go @@ -0,0 +1,21 @@ +package devproxy + +import ( + "net/http" +) + +// GetActionTag returns the action triggered by an http.Request . +func GetActionTag(req *http.Request) string { + // only log metrics for these paths and actions + pathToAction := map[string]string{ + "/favicon.ico": "favicon", + "/ping": "ping", + "/robots.txt": "robots", + } + // get the action from the url path + path := req.URL.Path + if action, ok := pathToAction[path]; ok { + return action + } + return "proxy" +} diff --git a/internal/devproxy/middleware.go b/internal/devproxy/middleware.go new file mode 100644 index 00000000..622e7605 --- /dev/null +++ b/internal/devproxy/middleware.go @@ -0,0 +1,67 @@ +package devproxy + +import ( + "net/http" + "net/url" +) + +// With inspiration from https://github.com/unrolled/secure +var securityHeaders = map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "SAMEORIGIN", + "X-XSS-Protection": "1; mode=block", +} + +// setHeaders ensures that every response includes some basic security headers. +// +// Note: the Strict-Transport-Security header is set by the requireHTTPS +// middleware below, to avoid issues with development environments that must +// allow plain HTTP. +func setSecurityHeaders(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for key, val := range securityHeaders { + rw.Header().Set(key, val) + } + h.ServeHTTP(rw, req) + }) +} + +func (p *DevProxy) setResponseHeaderOverrides(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + route, ok := p.router(req) + if ok && route.upstreamConfig.HeaderOverrides != nil { + for key, val := range route.upstreamConfig.HeaderOverrides { + rw.Header().Set(key, val) + } + } + h.ServeHTTP(rw, req) + }) +} + +// validateHost ensures that each request's host is valid +func (p *DevProxy) validateHost(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if _, ok := p.router(req); !ok { + p.UnknownHost(rw, req) + return + } + h.ServeHTTP(rw, req) + }) +} + +func requireHTTPS(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Strict-Transport-Security", "max-age=31536000") + if req.URL.Scheme != "https" && req.Header.Get("X-Forwarded-Proto") != "https" { + dest := &url.URL{ + Scheme: "https", + Host: req.Host, + Path: req.URL.Path, + RawQuery: req.URL.RawQuery, + } + http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently) + return + } + h.ServeHTTP(rw, req) + }) +} diff --git a/internal/devproxy/options.go b/internal/devproxy/options.go new file mode 100644 index 00000000..00c5c367 --- /dev/null +++ b/internal/devproxy/options.go @@ -0,0 +1,103 @@ +package devproxy + +import ( + "fmt" + "io/ioutil" + "net/url" + "os" + "strings" + "time" +) + +// Options are configuration options that can be set by Environment Variables +// Port - int - port to listen on for HTTP clients +// UpstreamConfigsFile - the path to upstream configs file +// Cluster - the cluster in which this is running, used for upstream configs +// Scheme - the default scheme, used for upstream configs +// DefaultUpstreamTimeout - the default time period to wait for a response from an upstream +// TCPWriteTimeout - http server tcp write timeout +// TCPReadTimeout - http server tcp read timeout +type Options struct { + Port int `envconfig:"PORT" default:"4180"` + UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"` + Scheme string `envconfig:"SCHEME" default:"https"` + Host string `envconfig:"HOST"` + DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT" default:"10s"` + TCPWriteTimeout time.Duration `envconfig:"TCP_WRITE_TIMEOUT" default:"30s"` + TCPReadTimeout time.Duration `envconfig:"TCP_READ_TIMEOUT" default:"30s"` + RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"` + RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"` + // This is an override for supplying template vars at test time + testTemplateVars map[string]string + // internal values that are set after config validation + upstreamConfigs []*UpstreamConfig +} + +// NewOptions returns a new options struct +func NewOptions() *Options { + return &Options{ + RequestLogging: true, + DefaultUpstreamTimeout: time.Duration(1) * time.Second, + } +} + +func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) { + parsed, err := url.Parse(toParse) + if err != nil { + return nil, append(msgs, fmt.Sprintf( + "error parsing %s-url=%q %s", urltype, toParse, err)) + } + return parsed, msgs +} + +// Validate validates options +func (o *Options) Validate() error { + msgs := make([]string, 0) + + if o.UpstreamConfigsFile == "" { + msgs = append(msgs, "missing setting: upstream-configs") + o.UpstreamConfigsFile = "internal/devproxy/testdata/upstream_configs.yml" + } + + if o.UpstreamConfigsFile != "" { + rawBytes, err := ioutil.ReadFile(o.UpstreamConfigsFile) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error reading upstream configs file: %s", err)) + } + + templateVars := parseEnvironment(os.Environ()) + if o.testTemplateVars != nil { + templateVars = o.testTemplateVars + } + + o.upstreamConfigs, err = loadServiceConfigs(rawBytes, "default", o.Scheme, templateVars) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error parsing upstream configs file %s", err)) + } + } + + if len(msgs) != 0 { + return fmt.Errorf("Invalid configuration:\n %s", + strings.Join(msgs, "\n ")) + } + return nil +} + +func parseEnvironment(environ []string) map[string]string { + envPrefix := "DEV_CONFIG_" + env := make(map[string]string) + if len(environ) == 0 { + return env + } + for _, e := range environ { + // we only include env keys that have the SSO_CONFIG_ prefix + if !strings.HasPrefix(e, envPrefix) { + continue + } + + split := strings.SplitN(e, "=", 2) + key := strings.ToLower(strings.TrimPrefix(split[0], envPrefix)) + env[key] = split[1] + } + return env +} diff --git a/internal/devproxy/request_signer.go b/internal/devproxy/request_signer.go new file mode 100644 index 00000000..7184a35e --- /dev/null +++ b/internal/devproxy/request_signer.go @@ -0,0 +1,202 @@ +package devproxy + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "hash" + "io/ioutil" + "net/http" + "strings" +) + +// Only headers enumerated in this list are used to compute the signature of a request. +var signedHeaders = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Groups", + "Cookie", +} + +// Name of the header used to transmit the signature computed for the request. +var signatureHeader = "Sso-Signature" +var signingKeyHeader = "kid" + +// RequestSigner exposes an interface for digitally signing requests using an RSA private key. +// See comments for the Sign() method below, for more on how this signature is constructed. +type RequestSigner struct { + hasher hash.Hash + signingKey crypto.Signer + publicKeyStr string + publicKeyID string +} + +// NewRequestSigner constructs a RequestSigner object from a PEM+PKCS8 encoded RSA public key. +func NewRequestSigner(signingKeyPemStr string) (*RequestSigner, error) { + var privateKey crypto.Signer + var publicKeyPEM []byte + + // Strip PEM encoding from private key. + block, _ := pem.Decode([]byte(signingKeyPemStr)) + if block == nil { + return nil, fmt.Errorf("could not read PEM block from signing key") + } + + // Extract private key as a crypto.Signer object. + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("could not read key from signing key bytes: %s", err) + } + privateKey = key.(crypto.Signer) + + // Derive public key. + rsaPublicKey, ok := privateKey.Public().(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("only RSA public keys are currently supported") + } + publicKeyPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: x509.MarshalPKCS1PublicKey(rsaPublicKey), + }) + + var keyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(publicKeyPEM) + keyHash = hasher.Sum(keyHash) + + return &RequestSigner{ + hasher: sha256.New(), + signingKey: privateKey, + publicKeyStr: string(publicKeyPEM), + publicKeyID: hex.EncodeToString(keyHash), + }, nil +} + +// mapRequestToHashInput returns a string representation of a Request, formatted as a +// newline-separated sequence of entries from the request. Any two Requests sharing the same +// representation are considered "equivalent" for purposes of verifying the integrity of a request. +// +// Representations are formatted as follows: +// +// ... +// +// +// +// where: +// is the ','-joined concatenation of all header values of `signedHeaders[k]`; empty +// values such as '' and all other headers in the request are ignored, +// is the string "(?)(#FRAGMENT)", where "?" and "#" are +// ommitted if the associated components are absent from the request URL, +// is the body of the Request (may be `nil`; e.g. for GET requests). +// +// Receiving endpoints authenticating the integrity of a request should reconstruct this document +// exactly, when verifying the contents of a received request. +func mapRequestToHashInput(req *http.Request) (string, error) { + entries := []string{} + + // Add signed headers. + for _, hdr := range signedHeaders { + hdrValues := removeEmpty(req.Header[hdr]) + if len(hdrValues) > 0 { + entries = append(entries, strings.Join(hdrValues, ",")) + } + } + + // Add canonical URL representation. Ignore URL {scheme, host, port, etc}. + entries = append(entries, func() string { + url := req.URL.Path + if len(req.URL.RawQuery) > 0 { + url += ("?" + req.URL.RawQuery) + } + if len(req.URL.Fragment) > 0 { + url += ("#" + req.URL.Fragment) + } + return url + }()) + + // Add request body, if present (may be absent for GET requests, etc). + if req.Body != nil { + body, _ := ioutil.ReadAll(req.Body) + req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + entries = append(entries, string(body)) + } + + // Return the join of all entries, with each separated by a newline. + return strings.Join(entries, "\n"), nil +} + +// Sign appends a header to the request, with a public-key encrypted signature derive from +// a subset of the request headers, together with the request URL and body. +// +// Signature is computed as: +// repr := Representation(request) <- Computed by mapRequestToHashInput() +// hash := SHA256(repr) +// sig := SIGN(hash, SigningKey) +// final := WEB_SAFE_BASE64(sig) +// The header `Sso-Signature` is given the value of `final`. +// +// Receiving endpoints authenticating the integrity of a request should: +// 1. Strip the WEB_SAFE_BASE64 encoding from the value of `signatureHeader`, +// 2. Decrypt the resulting value using the public key published by sso_proxy, thus obtaining the +// hash of the request representation, +// 3. Compute the request representation from the received request, using the same format as the +// mapRequestToHashInput() function above, +// 4. Apply SHA256 hash to the recomputed representation, and verify that it matches the decrypted +// hash value received through the `Sso-Signature` of the request. +// +// Any requests failing this check should be considered tampered with, and rejected. +func (signer RequestSigner) Sign(req *http.Request) error { + // Generate the request representation that will serve as hash input. + repr, err := mapRequestToHashInput(req) + if err != nil { + return fmt.Errorf("could not generate representation for request: %s", err) + } + + // Generate hash of the document buffer. + var documentHash []byte + signer.hasher.Reset() + _, _ = signer.hasher.Write([]byte(repr)) + documentHash = signer.hasher.Sum(documentHash) + + // Sign the documentHash with the signing key. + signatureBytes, err := signer.signingKey.Sign(rand.Reader, documentHash, crypto.SHA256) + if err != nil { + return fmt.Errorf("failed signing document hash with signing key: %s", err) + } + signature := base64.URLEncoding.EncodeToString(signatureBytes) + + // Set the signature and signing-key request headers. Return nil to indicate no error. + req.Header.Set(signatureHeader, signature) + req.Header.Set(signingKeyHeader, signer.publicKeyID) + return nil +} + +// PublicKey returns a pair (KeyID, Key), where: +// - KeyID is a unique identifier (currently the SHA256 hash of Key), +// - Key is the (PEM+PKCS1)-encoding of a public key, usable for validating signed requests. +func (signer RequestSigner) PublicKey() (string, string) { + return signer.publicKeyID, signer.publicKeyStr +} + +func removeEmpty(s []string) []string { + r := []string{} + for _, str := range s { + if len(str) > 0 { + r = append(r, str) + } + } + return r +} diff --git a/internal/devproxy/request_signer_test.go b/internal/devproxy/request_signer_test.go new file mode 100644 index 00000000..ced5b190 --- /dev/null +++ b/internal/devproxy/request_signer_test.go @@ -0,0 +1,179 @@ +package devproxy + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/buzzfeed/sso/internal/pkg/testutil" +) + +// Convenience variables and utilities. +var urlExample = "https://foo.sso.example.com/path" + +func addHeaders(req *http.Request, examples []string, extras map[string][]string) { + var signedHeaderExamples = map[string][]string{ + "Content-Length": {"1234"}, + "Content-Md5": {"F00D"}, + "Content-Type": {"application/json"}, + "Date": {"2018-11-08"}, + "Authorization": {"Bearer ab12cd34"}, + "X-Forwarded-User": {"octoboi"}, + "X-Forwarded-Email": {"octoboi@example.com"}, + "X-Forwarded-Groups": {"molluscs", "security_applications"}, + } + + for _, signedHdr := range examples { + for _, value := range signedHeaderExamples[signedHdr] { + req.Header.Add(signedHdr, value) + } + } + for extraHdr, values := range extras { + for _, value := range values { + req.Header.Add(extraHdr, value) + } + } +} + +func TestRepr_UrlRepresentation(t *testing.T) { + testURL := func(url string, expect string) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Errorf("could not build request: %s", err) + } + + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, expect, repr) + } + + testURL("http://foo.sso.example.com/path/to/resource", "/path/to/resource") + testURL("http://foo.sso.example.com/path?", "/path") + testURL("http://foo.sso.example.com/path/to?query#fragment", "/path/to?query#fragment") + testURL("https://foo.sso.example.com:4321/path#fragment", "/path#fragment") + testURL("http://foo.sso.example.com/path?query¶m=value#", "/path?query¶m=value") +} + +func TestRepr_HeaderRepresentation(t *testing.T) { + testHeaders := func(include []string, extra map[string][]string, expect string) { + req, err := http.NewRequest("GET", urlExample, nil) + if err != nil { + t.Errorf("could not build request: %s", err) + } + addHeaders(req, include, extra) + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, expect, repr) + } + + // Partial set of signed headers. + testHeaders([]string{"Authorization", "X-Forwarded-Groups"}, nil, + "Bearer ab12cd34\n"+ + "molluscs,security_applications\n"+ + "/path") + + // Full set of signed headers. + testHeaders(signedHeaders, nil, + "1234\n"+ + "F00D\n"+ + "application/json\n"+ + "2018-11-08\n"+ + "Bearer ab12cd34\n"+ + "octoboi\n"+ + "octoboi@example.com\n"+ + "molluscs,security_applications\n"+ + "/path") + + // Partial set of signed headers, plus another header (should not appear in representation). + testHeaders([]string{"Authorization", "X-Forwarded-Email"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}, + "Bearer ab12cd34\n"+ + "octoboi@example.com\n"+ + "/path") + + // Only unsigned headers. + testHeaders(nil, map[string][]string{"X-Octopus-Stuff": {"83721"}}, "/path") +} + +func TestRepr_PostWithBody(t *testing.T) { + req, err := http.NewRequest("POST", urlExample, strings.NewReader("something\nor other")) + if err != nil { + t.Errorf("could not build request: %s", err) + } + addHeaders(req, []string{"X-Forwarded-Email", "X-Forwarded-Groups"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}) + + repr, err := mapRequestToHashInput(req) + if err != nil { + t.Errorf("could not map request to hash input: %s", err) + } + testutil.Equal(t, + "octoboi@example.com\n"+ + "molluscs,security_applications\n"+ + "/path\n"+ + "something\n"+ + "or other", + repr) +} + +func TestSignatureRoundTripDecoding(t *testing.T) { + // Keys used for signing/validating request. + privateKey, err := ioutil.ReadFile("testdata/private_key.pem") + testutil.Assert(t, err == nil, "error reading private key from testdata") + + publicKey, err := ioutil.ReadFile("testdata/public_key.pem") + testutil.Assert(t, err == nil, "error reading public key from testdata") + + // Build the RequestSigner object used to generate the request signature header. + requestSigner, err := NewRequestSigner(string(privateKey)) + testutil.Assert(t, err == nil, "could not initialize request signer: %s", err) + + // And build the rsa.PublicKey object that will help verify the signature. + verifierKey, err := func() (*rsa.PublicKey, error) { + if block, _ := pem.Decode(publicKey); block == nil { + return nil, fmt.Errorf("could not read PEM block from public key") + } else if key, err := x509.ParsePKCS1PublicKey(block.Bytes); err != nil { + return nil, fmt.Errorf("could not read key from public key bytes: %s", err) + } else { + return key, nil + } + }() + testutil.Assert(t, err == nil, "could not construct public key: %s", err) + + // Build the Request to be signed. + req, err := http.NewRequest("POST", urlExample, strings.NewReader("something\nor other")) + testutil.Assert(t, err == nil, "could not construct request: %s", err) + addHeaders(req, []string{"X-Forwarded-Email", "X-Forwarded-Groups"}, + map[string][]string{"X-Octopus-Stuff": {"54321"}}) + + // Sign the request, and extract its signature from the header. + err = requestSigner.Sign(req) + testutil.Assert(t, err == nil, "could not sign request: %s", err) + sig, _ := base64.URLEncoding.DecodeString(req.Header.Get("Sso-Signature")) + + // Hardcoded expected hash, computed from the request. + expectedHash, _ := hex.DecodeString( + "04158c00fbecccd8b5dca58634a0a7f28bf5ad908f19cb1b404bdd37bb4485a9") + err = rsa.VerifyPKCS1v15(verifierKey, crypto.SHA256, expectedHash, sig) + testutil.Assert(t, err == nil, "could not verify request signature: %s", err) + + // Verify that the signing-key header is the hash of the public-key. + var pubKeyHash []byte + hasher := sha256.New() + _, _ = hasher.Write(publicKey) + pubKeyHash = hasher.Sum(pubKeyHash) + testutil.Equal(t, hex.EncodeToString(pubKeyHash), req.Header.Get("kid")) +} diff --git a/internal/devproxy/templates.go b/internal/devproxy/templates.go new file mode 100644 index 00000000..3b696b53 --- /dev/null +++ b/internal/devproxy/templates.go @@ -0,0 +1,126 @@ +package devproxy + +import ( + "html/template" +) + +func getTemplates() *template.Template { + t := template.New("foo") + t = template.Must(t.Parse(`{{define "error.html"}} + + + + Error + + + + + +
+
+
+

{{.Title}}

+
+

+ {{.Message}}
+ HTTP {{.Code}} +

+ {{if ne .Code 403 }} +
+ +
+ {{end}} +
+ +
+ +{{end}}`)) + return t +} diff --git a/internal/devproxy/testdata/.env b/internal/devproxy/testdata/.env new file mode 100644 index 00000000..9e9b4a15 --- /dev/null +++ b/internal/devproxy/testdata/.env @@ -0,0 +1,11 @@ +export PORT=4888 +export UPSTREAM_CONFIGS=/path/to/upstream_configs.yml +export SCHEME=http +export HOST=http://localhost/ +export CLUSTER=sso-dev +export DEFAULT_UPSTREAM_TIMEOUT=10s +export TCP_WRITE_TIMEOUT=30s +export TCP_READ_TIMEOUT=30s +export REQUEST_LOGGING=true +export REQUEST_SIGNATURE_KEY=$(cat /path/to/private_key.pem) +export DEV_CONFIG_DEVSHIM_SIGNING_KEY="sha256:shared-secret-value" diff --git a/internal/devproxy/testdata/private_key.pem b/internal/devproxy/testdata/private_key.pem new file mode 100644 index 00000000..03a16bd3 --- /dev/null +++ b/internal/devproxy/testdata/private_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCy38IQCH8QyeNF +s1zA0XuIyqnTcSfYZg0nPfB+K//pFy7tIOAwmR6th8NykrxFhEQDHKNCmLXt4j8V +FDHQZtGjUBHRmAXZW8NOQ0EI1vc/Dpt09sU40JQlXZZeL+9/7iAxEfSE3TQr1k7P +Xwxpjm9rsLSn7FoLnvXco0mc6+d2jjxf4cMgJIaQLKOd783KUQzLVEvBQJ05JnpI +2xMjS0q33ltMTMGF3QZQN9i4bZKgnItomKxTJbfxftO11FTNLB7og94sWmlThAY5 +/UMjZaWYJ1g89+WUJ+KpVYyJsHPBBkaQG+NYazcLDyIowpzJ1WVkInysshpTqwT+ +UPV4at+jAgMBAAECggEAX8lxK5LRMJVcLlwRZHQJekRE0yS6WKi1jHkfywEW5qRy +jatYQs4MXpLgN/+Z8IQWw6/XQXdznTLV4xzQXDBjPNhI4ntNTotUOBnNvsUW296f +ou/uxzDy1FuchU2YLGLBPGXIEko+gOcfhu74P6J1yi5zX6UyxxxVvtR2PCEb7yDw +m2881chwMblZ5Z8uyF++ajkK3/rqLk64w29+K4ZTDbTcCp5NtBYx2qSEU7yp12rc +qscUGqxG00Abx+osI3cUn0kOq7356LeR1rfA15yZwOb+s28QYp2WPlVB2hOiYXQv ++ttEOpt0x1QJhBAsFgwY173sD5w2MryRQb1RCwBvqQKBgQDeTdbRzxzAl83h/mAq +5I+pNEz57veAFVO+iby7TbZ/0w6q+QeT+bHF+TjGHiSlbtg3nd9NPrex2UjiN7ej ++DrxhsSLsP1ZfwDNv6f1Ii1HluJclUFSUNU/LntBjqqCJ959lniNp1y5+ZQ/j2Rf ++ZraVsHRB0itilFeAl5+n7CfxwKBgQDN/K+E1TCbp1inU60Lc9zeb8fqTEP6Mp36 +qQ0Dp+KMLPJ0xQSXFq9ILr4hTJlBqfmTkfmQUcQuwercZ3LNQPbsuIg96bPW73R1 +toXjokd6jUn5sJXCOE0RDumcJrL1VRf9RN1AmM4CgCc/adUMjws3pBc5R4An7UyU +ouRQhN+5RQKBgFOVTrzqM3RSX22mWAAomb9T09FxQQueeTM91IFUMdcTwwMTyP6h +Nm8qSmdrM/ojmBYpPKlteGHdQaMUse5rybXAJywiqs84ilPRyNPJOt8c4xVOZRYP +IG62Ck/W1VNErEnqBn+0OpAOP+g6ANJ5JfkL/6mZJIFjbT58g4z2e9FHAoGBAM3f +uBkd7lgTuLJ8Gh6xLVYQCJHuqZ49ytFE9qHpwK5zGdyFMSJE5OlS9mpXoXEUjkHk +iraoUlidLbwdlIr6XBCaGmku07SFXTNtOoIZpjEhV4c762HTXYsoCWos733uD2zt +z+iJEJVFOnTRtMK5kO+KjD+Oa9L8BCcmauTi+Ku1AoGAZBUzi95THA60hPXI0hm/ +o0J5mfLkFPfhpUmDAMaEpv3bM4byA+IGXSZVc1IZO6cGoaeUHD2Yl1m9a5tv5rF+ +FS9Ht+IgATvGojah+xxQy+kf6tRB9Hn4scyq+64AesXlDbWDEagomQ0hyV/JKSS6 +LQatvnCmBd9omRT2uwYUo+o= +-----END PRIVATE KEY----- diff --git a/internal/devproxy/testdata/public_key.pem b/internal/devproxy/testdata/public_key.pem new file mode 100644 index 00000000..cccac43b --- /dev/null +++ b/internal/devproxy/testdata/public_key.pem @@ -0,0 +1,8 @@ +-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke +rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU +JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG +kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3 +8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3 +Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB +-----END RSA PUBLIC KEY----- diff --git a/internal/devproxy/testdata/upstream_configs.yml b/internal/devproxy/testdata/upstream_configs.yml new file mode 100644 index 00000000..8c8b4767 --- /dev/null +++ b/internal/devproxy/testdata/upstream_configs.yml @@ -0,0 +1,11 @@ +- service: devshim + default: + from: http://localhost:4888 + to: http://localhost:4810 + options: + skip_request_signing: false + user_info: + user: testUser + groups: team + email: testtest@remitly.com +