Skip to content

Commit

Permalink
CSRF: Fix additional headers option (grafana#50629)
Browse files Browse the repository at this point in the history
* CSRF: Fix additional headers option

* fix: type assertion on error fail on wrapped error

* Update pkg/middleware/csrf/csrf_test.go

Co-authored-by: Emil Tullstedt <[email protected]>

* update test

Co-authored-by: eleijonmarck <[email protected]>
  • Loading branch information
sakjur and eleijonmarck authored Jul 13, 2022
1 parent ab6cf9e commit 06bd8b8
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 78 deletions.
2 changes: 1 addition & 1 deletion pkg/api/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
}

m.Use(middleware.Recovery(hs.Cfg))
m.UseMiddleware(hs.Csrf.Middleware(hs.log))
m.UseMiddleware(hs.Csrf.Middleware())

hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build")
hs.mapStatic(m, hs.Cfg.StaticRootPath, "", "public", "/public/views/swagger.html")
Expand Down
181 changes: 106 additions & 75 deletions pkg/middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,112 +2,62 @@ package csrf

import (
"errors"
"fmt"
"net/http"
"net/url"
"reflect"

"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)

type Service interface {
Middleware(logger log.Logger) func(http.Handler) http.Handler
Middleware() func(http.Handler) http.Handler
TrustOrigin(origin string)
AddOriginHeader(headerName string)
AddAdditionalHeaders(headerName string)
AddSafeEndpoint(endpoint string)
}

type Implementation struct {
type CSRF struct {
cfg *setting.Cfg

trustedOrigins map[string]struct{}
originHeaders map[string]struct{}
headers map[string]struct{}
safeEndpoints map[string]struct{}
}

func ProvideCSRFFilter(cfg *setting.Cfg) Service {
i := &Implementation{
c := &CSRF{
cfg: cfg,
trustedOrigins: map[string]struct{}{},
originHeaders: map[string]struct{}{
"Origin": {},
},
safeEndpoints: map[string]struct{}{},
headers: map[string]struct{}{},
safeEndpoints: map[string]struct{}{},
}

additionalHeaders := cfg.SectionWithEnvOverrides("security").Key("csrf_additional_headers").Strings(" ")
trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ")

for _, header := range additionalHeaders {
i.originHeaders[header] = struct{}{}
c.headers[header] = struct{}{}
}
for _, origin := range trustedOrigins {
i.trustedOrigins[origin] = struct{}{}
c.trustedOrigins[origin] = struct{}{}
}

return i
return c
}

func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler {
// As per RFC 7231/4.2.2 these methods are idempotent:
// (GET is excluded because it may have side effects in some APIs)
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}

func (c *CSRF) Middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If request has no login cookie - skip CSRF checks
if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) {
next.ServeHTTP(w, r)
return
}
// Skip CSRF checks for "safe" methods
for _, method := range safeMethods {
if r.Method == method {
next.ServeHTTP(w, r)
return
}
}
// Skip CSRF checks for "safe" endpoints
for safeEndpoint := range i.safeEndpoints {
if r.URL.Path == safeEndpoint {
next.ServeHTTP(w, r)
return
}
}
// Otherwise - verify that Origin matches the server origin
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
origins := map[string]struct{}{}
for header := range i.originHeaders {
origin, err := url.Parse(r.Header.Get(header))
if err != nil {
logger.Error("error parsing Origin header", "header", header, "err", err)
}
if origin.String() != "" {
origins[origin.Hostname()] = struct{}{}
}
}

// No Origin header sent, skip CSRF check.
if len(origins) == 0 {
next.ServeHTTP(w, r)
return
}
e := &errorWithStatus{}

trustedOrigin := false
for o := range i.trustedOrigins {
if _, ok := origins[o]; ok {
trustedOrigin = true
break
err := c.check(r)
if err != nil {
if !errors.As(err, &e) {
http.Error(w, fmt.Sprintf("internal server error: expected error type errorWithStatus, got %s. Error: %v", reflect.TypeOf(err), err), http.StatusInternalServerError)
}
}

_, hostnameMatches := origins[netAddr.Host]
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches {
http.Error(w, "origin not allowed", http.StatusForbidden)
http.Error(w, err.Error(), e.HTTPStatus)
return
}

Expand All @@ -116,15 +66,96 @@ func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.H
}
}

func (i *Implementation) TrustOrigin(origin string) {
i.trustedOrigins[origin] = struct{}{}
func (c *CSRF) check(r *http.Request) error {
// As per RFC 7231/4.2.2 these methods are idempotent:
// (GET is excluded because it may have side effects in some APIs)
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}

// If request has no login cookie - skip CSRF checks
if _, err := r.Cookie(c.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) {
return nil
}
// Skip CSRF checks for "safe" methods
for _, method := range safeMethods {
if r.Method == method {
return nil
}
}
// Skip CSRF checks for "safe" endpoints
for safeEndpoint := range c.safeEndpoints {
if r.URL.Path == safeEndpoint {
return nil
}
}
// Otherwise - verify that Origin matches the server origin
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
if err != nil {
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest}
}

o := r.Header.Get("Origin")

// No Origin header sent, skip CSRF check.
if o == "" {
return nil
}

originURL, err := url.Parse(o)
if err != nil {
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest}
}
origin := originURL.Hostname()

trustedOrigin := false
for h := range c.headers {
customHost := r.Header.Get(h)
addr, err := util.SplitHostPortDefault(customHost, "", "0") // we ignore the port
if err != nil {
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest}
}
if addr.Host == origin {
trustedOrigin = true
break
}
}

for o := range c.trustedOrigins {
if o == origin {
trustedOrigin = true
break
}
}

hostnameMatches := origin == netAddr.Host
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches {
return &errorWithStatus{Underlying: errors.New("origin not allowed"), HTTPStatus: http.StatusForbidden}
}

return nil
}

func (c *CSRF) TrustOrigin(origin string) {
c.trustedOrigins[origin] = struct{}{}
}

func (i *Implementation) AddOriginHeader(headerName string) {
i.originHeaders[headerName] = struct{}{}
func (c *CSRF) AddAdditionalHeaders(headerName string) {
c.headers[headerName] = struct{}{}
}

// AddSafeEndpoint is used for endpoints requests to skip CSRF check
func (i *Implementation) AddSafeEndpoint(endpoint string) {
i.safeEndpoints[endpoint] = struct{}{}
func (c *CSRF) AddSafeEndpoint(endpoint string) {
c.safeEndpoints[endpoint] = struct{}{}
}

type errorWithStatus struct {
Underlying error
HTTPStatus int
}

func (e errorWithStatus) Error() string {
return e.Underlying.Error()
}

func (e errorWithStatus) Unwrap() error {
return e.Underlying
}
117 changes: 115 additions & 2 deletions pkg/middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package csrf

import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/setting"
)

Expand Down Expand Up @@ -100,6 +102,117 @@ func TestMiddlewareCSRF(t *testing.T) {
}
}

func TestCSRF_Check(t *testing.T) {
tests := []struct {
name string
request *http.Request
addtHeader map[string]struct{}
trustedOrigins map[string]struct{}
safeEndpoints map[string]struct{}
expectedOK bool
expectedStatus int
}{
{
name: "base case",
request: postRequest(t, "", nil),
expectedOK: true,
},
{
name: "base with null origin header",
request: postRequest(t, "", map[string]string{"Origin": "null"}),
expectedStatus: http.StatusForbidden,
},
{
name: "grafana.org",
request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.org"}),
expectedOK: true,
},
{
name: "grafana.org with X-Forwarded-Host",
request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.org"}),
expectedStatus: http.StatusForbidden,
},
{
name: "grafana.org with X-Forwarded-Host and header trusted",
request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.org"}),
addtHeader: map[string]struct{}{"X-Forwarded-Host": {}},
expectedOK: true,
},
{
name: "grafana.org from grafana.com",
request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.com"}),
expectedStatus: http.StatusForbidden,
},
{
name: "grafana.org from grafana.com explicit trust for grafana.com",
request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.com"}),
trustedOrigins: map[string]struct{}{"grafana.com": {}},
expectedOK: true,
},
{
name: "grafana.org from grafana.com with X-Forwarded-Host and header trusted",
request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.com"}),
addtHeader: map[string]struct{}{"X-Forwarded-Host": {}},
trustedOrigins: map[string]struct{}{"grafana.com": {}},
expectedOK: true,
},
{
name: "safe endpoint",
request: postRequest(t, "example.org/foo/bar", map[string]string{"Origin": "null"}),
safeEndpoints: map[string]struct{}{"foo/bar": {}},
expectedOK: true,
},
}

for _, tc := range tests {
tc := tc

t.Run(tc.name, func(t *testing.T) {
c := CSRF{
cfg: setting.NewCfg(),
trustedOrigins: tc.trustedOrigins,
headers: tc.addtHeader,
safeEndpoints: tc.safeEndpoints,
}
c.cfg.LoginCookieName = "LoginCookie"

err := c.check(tc.request)
if tc.expectedOK {
require.NoError(t, err)
} else {
require.Error(t, err)
var actual *errorWithStatus
require.True(t, errors.As(err, &actual))
assert.EqualValues(t, tc.expectedStatus, actual.HTTPStatus)
}
})
}
}

func postRequest(t testing.TB, hostname string, headers map[string]string) *http.Request {
t.Helper()
urlParts := strings.SplitN(hostname, "/", 2)

path := "/"
if len(urlParts) == 2 {
path = urlParts[1]
}
r, err := http.NewRequest(http.MethodPost, path, nil)
require.NoError(t, err)

r.Host = urlParts[0]

r.AddCookie(&http.Cookie{
Name: "LoginCookie",
Value: "this should not be important",
})

for k, v := range headers {
r.Header.Set(k, v)
}
return r
}

func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httptest.ResponseRecorder {
req, err := http.NewRequest(method, "/", nil)
if err != nil {
Expand All @@ -123,7 +236,7 @@ func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httpte
cfg := setting.NewCfg()
cfg.LoginCookieName = cookieName
service := ProvideCSRFFilter(cfg)
handler := service.Middleware(log.New())(testHandler)
handler := service.Middleware()(testHandler)
handler.ServeHTTP(rr, req)
return rr
}

0 comments on commit 06bd8b8

Please sign in to comment.