diff --git a/contrib/internal/httptrace/make_responsewriter.go b/contrib/internal/httptrace/make_responsewriter.go index da08b62d99..2458ef7fbf 100644 --- a/contrib/internal/httptrace/make_responsewriter.go +++ b/contrib/internal/httptrace/make_responsewriter.go @@ -65,6 +65,8 @@ func wrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *responseWr type monitoredResponseWriter interface { http.ResponseWriter Status() int + Block() + Blocked() bool Unwrap() http.ResponseWriter } switch { diff --git a/contrib/internal/httptrace/response_writer.go b/contrib/internal/httptrace/response_writer.go index 2bbc31bad7..964a9a83a9 100644 --- a/contrib/internal/httptrace/response_writer.go +++ b/contrib/internal/httptrace/response_writer.go @@ -7,17 +7,31 @@ package httptrace //go:generate sh -c "go run make_responsewriter.go | gofmt > trace_gen.go" -import "net/http" +import ( + "net/http" + "sync" + + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" +) + +var warnLogOnce sync.Once + +const warnLogMsg = `appsec: http.ResponseWriter was used after a security blocking decision was enacted. +Please check for gopkg.in/DataDog/dd-trace-go.v1/appsec/events.BlockingSecurityEvent in the error result value of instrumented functions.` + +// TODO(eliott.bouhana): add a link to the appsec SDK documentation ^^^ here ^^^ // responseWriter is a small wrapper around an http response writer that will // intercept and store the status of a request. type responseWriter struct { http.ResponseWriter - status int + status int + blocked bool } func newResponseWriter(w http.ResponseWriter) *responseWriter { - return &responseWriter{w, 0} + return &responseWriter{w, 0, false} } // Status returns the status code that was monitored. @@ -25,10 +39,31 @@ func (w *responseWriter) Status() int { return w.status } +// Blocked returns whether the response has been blocked. +func (w *responseWriter) Blocked() bool { + return w.blocked +} + +// Block is supposed only once, after a response (one made by appsec code) as been sent. If it not the case, the function will do nothing. +// All subsequent calls to Write and WriteHeader will be trigger a log warning users that the response has been blocked. +func (w *responseWriter) Block() { + if !appsec.Enabled() || w.status == 0 { + return + } + + w.blocked = true +} + // Write writes the data to the connection as part of an HTTP reply. // We explicitly call WriteHeader with the 200 status code // in order to get it reported into the span. func (w *responseWriter) Write(b []byte) (int, error) { + if w.blocked { + warnLogOnce.Do(func() { + log.Warn(warnLogMsg) + }) + return len(b), nil + } if w.status == 0 { w.WriteHeader(http.StatusOK) } @@ -38,11 +73,16 @@ func (w *responseWriter) Write(b []byte) (int, error) { // WriteHeader sends an HTTP response header with status code. // It also sets the status code to the span. func (w *responseWriter) WriteHeader(status int) { - if w.status != 0 { + if w.blocked { + warnLogOnce.Do(func() { + log.Warn(warnLogMsg) + }) return } + if w.status == 0 { + w.status = status + } w.ResponseWriter.WriteHeader(status) - w.status = status } // Unwrap returns the underlying wrapped http.ResponseWriter. diff --git a/contrib/internal/httptrace/response_writer_test.go b/contrib/internal/httptrace/response_writer_test.go index 78d5ffc6e2..c8f9d69a9d 100644 --- a/contrib/internal/httptrace/response_writer_test.go +++ b/contrib/internal/httptrace/response_writer_test.go @@ -7,9 +7,12 @@ package httptrace import ( "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" ) func Test_wrapResponseWriter(t *testing.T) { @@ -32,5 +35,40 @@ func Test_wrapResponseWriter(t *testing.T) { _, ok = w.(http.Pusher) assert.True(t, ok) }) +} + +func TestBlock(t *testing.T) { + appsec.Start() + defer appsec.Stop() + + if !appsec.Enabled() { + t.Skip("appsec is not enabled") + } + + t.Run("block-before-first-write", func(t *testing.T) { + recorder := httptest.NewRecorder() + rw := newResponseWriter(recorder) + rw.Block() + assert.False(t, rw.blocked) + + rw.WriteHeader(http.StatusForbidden) + rw.Block() + assert.True(t, rw.blocked) + + assert.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("write-after-block", func(t *testing.T) { + recorder := httptest.NewRecorder() + rw := newResponseWriter(recorder) + rw.WriteHeader(http.StatusForbidden) + rw.Write([]byte("foo")) + rw.Block() + rw.WriteHeader(http.StatusOK) + rw.Write([]byte("bar")) + + assert.Equal(t, http.StatusForbidden, recorder.Code) + assert.Equal(t, recorder.Body.String(), "foo") + }) } diff --git a/contrib/internal/httptrace/trace_gen.go b/contrib/internal/httptrace/trace_gen.go index 24e261838e..36a7f6f2bf 100644 --- a/contrib/internal/httptrace/trace_gen.go +++ b/contrib/internal/httptrace/trace_gen.go @@ -31,6 +31,8 @@ func wrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *responseWr type monitoredResponseWriter interface { http.ResponseWriter Status() int + Block() + Blocked() bool Unwrap() http.ResponseWriter } switch { diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go index 9cd849cc20..14d7eee319 100644 --- a/contrib/labstack/echo.v4/appsec.go +++ b/contrib/labstack/echo.v4/appsec.go @@ -21,7 +21,10 @@ func withAppSec(next echo.HandlerFunc, span tracer.Span) echo.HandlerFunc { for _, n := range c.ParamNames() { params[n] = c.Param(n) } - var err error + var ( + err error + writer = &statusResponseWriter{Response: c.Response()} + ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) err = next(c) @@ -32,7 +35,7 @@ func withAppSec(next echo.HandlerFunc, span tracer.Span) echo.HandlerFunc { } }) // Wrap the echo response to allow monitoring of the response status code in httpsec.WrapHandler() - httpsec.WrapHandler(handler, span, params, nil).ServeHTTP(&statusResponseWriter{Response: c.Response()}, c.Request()) + httpsec.WrapHandler(handler, span, params, nil).ServeHTTP(, c.Request()) // If an error occurred, wrap it under an echo.HTTPError. We need to do this so that APM doesn't override // the response code tag with 500 in case it doesn't recognize the error type. if _, ok := err.(*echo.HTTPError); !ok && err != nil { diff --git a/contrib/net/http/make_responsewriter.go b/contrib/net/http/make_responsewriter.go deleted file mode 100644 index 13ac9a8a14..0000000000 --- a/contrib/net/http/make_responsewriter.go +++ /dev/null @@ -1,88 +0,0 @@ -// Unless explicitly stated otherwise all files in this repository are licensed -// under the Apache License Version 2.0. -// This product includes software developed at Datadog (https://www.datadoghq.com/). -// Copyright 2016 Datadog, Inc. - -//go:build ignore -// +build ignore - -// This program generates wrapper implementations of http.ResponseWriter that -// also satisfy http.Flusher, http.Pusher, http.CloseNotifier and http.Hijacker, -// based on whether or not the passed in http.ResponseWriter also satisfies -// them. - -package main - -import ( - "os" - "text/template" - - "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/lists" -) - -func main() { - interfaces := []string{"Flusher", "Pusher", "CloseNotifier", "Hijacker"} - var combos [][][]string - for pick := len(interfaces); pick > 0; pick-- { - combos = append(combos, lists.Combinations(interfaces, pick)) - } - template.Must(template.New("").Parse(tpl)).Execute(os.Stdout, map[string]interface{}{ - "Interfaces": interfaces, - "Combinations": combos, - }) -} - -var tpl = `// Unless explicitly stated otherwise all files in this repository are licensed -// under the Apache License Version 2.0. -// This product includes software developed at Datadog (https://www.datadoghq.com/). -// Copyright 2016 Datadog, Inc. - -// Code generated by make_responsewriter.go DO NOT EDIT - -package http - -import "net/http" - - -// wrapResponseWriter wraps an underlying http.ResponseWriter so that it can -// trace the http response codes. It also checks for various http interfaces -// (Flusher, Pusher, CloseNotifier, Hijacker) and if the underlying -// http.ResponseWriter implements them it generates an unnamed struct with the -// appropriate fields. -// -// This code is generated because we have to account for all the permutations -// of the interfaces. -// -// In case of any new interfaces or methods we didn't consider here, we also -// implement the rwUnwrapper interface, which is used internally by -// the standard library: https://github.com/golang/go/blob/6d89b38ed86e0bfa0ddaba08dc4071e6bb300eea/src/net/http/responsecontroller.go#L42-L44 -func wrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *responseWriter) { -{{- range .Interfaces }} - h{{.}}, ok{{.}} := w.(http.{{.}}) -{{- end }} - - mw := newResponseWriter(w) - type monitoredResponseWriter interface { - http.ResponseWriter - Status() int - Unwrap() http.ResponseWriter - } - switch { -{{- range .Combinations }} - {{- range . }} - case {{ range $i, $v := . }}{{ if gt $i 0 }} && {{ end }}ok{{ $v }}{{ end }}: - w = struct { - monitoredResponseWriter - {{- range . }} - http.{{.}} - {{- end }} - }{mw{{ range . }}, h{{.}}{{ end }}} - {{- end }} -{{- end }} - default: - w = mw - } - - return w, mw -} -` diff --git a/internal/appsec/emitter/httpsec/http.go b/internal/appsec/emitter/httpsec/http.go index 07f6dd9ba8..6adcfbf8fa 100644 --- a/internal/appsec/emitter/httpsec/http.go +++ b/internal/appsec/emitter/httpsec/http.go @@ -56,20 +56,31 @@ type ( func (HandlerOperationArgs) IsArgOf(*HandlerOperation) {} func (HandlerOperationRes) IsResultOf(*HandlerOperation) {} -func StartOperation(ctx context.Context, args HandlerOperationArgs) (*HandlerOperation, *atomic.Pointer[actions.BlockHTTP], context.Context) { - wafOp, ctx := waf.StartContextOperation(ctx) +func StartOperation(w http.ResponseWriter, r *http.Request, pathParams map[string]string, opts *Config) (*HandlerOperation, *atomic.Bool, context.Context) { + wafOp, ctx := waf.StartContextOperation(r.Context()) op := &HandlerOperation{ Operation: dyngo.NewOperation(wafOp), ContextOperation: wafOp, } - // We need to use an atomic pointer to store the action because the action may be created asynchronously in the future - var action atomic.Pointer[actions.BlockHTTP] + var blocked atomic.Bool dyngo.OnData(op, func(a *actions.BlockHTTP) { - action.Store(a) + a.Handler.ServeHTTP(w, r) + for _, f := range opts.OnBlock { + f() + } }) - return op, &action, dyngo.StartAndRegisterOperation(ctx, op, args) + return op, &blocked, dyngo.StartAndRegisterOperation(ctx, op, HandlerOperationArgs{ + Method: r.Method, + RequestURI: r.RequestURI, + Host: r.Host, + RemoteAddr: r.RemoteAddr, + Headers: r.Header, + Cookies: makeCookies(r.Cookies()), + QueryParams: r.URL.Query(), + PathParams: pathParams, + }) } // Finish the HTTP handler operation and its children operations and write everything to the service entry span. @@ -125,18 +136,7 @@ func BeforeHandle( opts.ResponseHeaderCopier = defaultWrapHandlerConfig.ResponseHeaderCopier } - op, blockAtomic, ctx := StartOperation(r.Context(), HandlerOperationArgs{ - Method: r.Method, - RequestURI: r.RequestURI, - Host: r.Host, - RemoteAddr: r.RemoteAddr, - Headers: r.Header, - Cookies: makeCookies(r.Cookies()), - QueryParams: r.URL.Query(), - PathParams: pathParams, - }) - tr := r.WithContext(ctx) - var blocked atomic.Bool + op, blocked, ctx := StartOperation(w, r, pathParams, opts) afterHandle := func() { var statusCode int @@ -147,28 +147,9 @@ func BeforeHandle( Headers: opts.ResponseHeaderCopier(w), StatusCode: statusCode, }, span) - - if blockPtr := blockAtomic.Swap(nil); blockPtr != nil { - blockPtr.Handler.ServeHTTP(w, tr) - blocked.Store(true) - } - - // Execute the onBlock functions to make sure blocking works properly - // in case we are instrumenting the Gin framework - if blocked.Load() { - for _, f := range opts.OnBlock { - f() - } - } - } - - if blockPtr := blockAtomic.Swap(nil); blockPtr != nil { - // handler is replaced - blockPtr.Handler.ServeHTTP(w, tr) - blocked.Store(true) } - return w, tr, afterHandle, blocked.Load() + return w, r.WithContext(ctx), afterHandle, blocked.Load() } // WrapHandler wraps the given HTTP handler with the abstract HTTP operation defined by HandlerOperationArgs and @@ -177,7 +158,6 @@ func BeforeHandle( // It is a specific patch meant for Gin, for which we must abort the // context since it uses a queue of handlers and it's the only way to make // sure other queued handlers don't get executed. -// TODO: this patch must be removed/improved when we rework our actions/operations system func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string]string, opts *Config) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tw, tr, afterHandle, handled := BeforeHandle(w, r, span, pathParams, opts) diff --git a/internal/appsec/emitter/waf/actions/actions.go b/internal/appsec/emitter/waf/actions/actions.go index 4eabcfaff6..1f9cd30bd5 100644 --- a/internal/appsec/emitter/waf/actions/actions.go +++ b/internal/appsec/emitter/waf/actions/actions.go @@ -25,7 +25,7 @@ var actionHandlers = map[string]actionHandler{} func registerActionHandler(aType string, handler actionHandler) { if _, ok := actionHandlers[aType]; ok { - log.Warn("appsec: action type `%s` already registered", aType) + log.Debug("appsec: action type `%s` already registered", aType) return } actionHandlers[aType] = handler diff --git a/internal/appsec/emitter/waf/actions/actions_test.go b/internal/appsec/emitter/waf/actions/actions_test.go index 1e77c8e2d1..434c091057 100644 --- a/internal/appsec/emitter/waf/actions/actions_test.go +++ b/internal/appsec/emitter/waf/actions/actions_test.go @@ -17,9 +17,9 @@ import ( func TestNewHTTPBlockRequestAction(t *testing.T) { mux := http.NewServeMux() srv := httptest.NewServer(mux) - mux.HandleFunc("/json", newHTTPBlockRequestAction(403, "json").ServeHTTP) - mux.HandleFunc("/html", newHTTPBlockRequestAction(403, "html").ServeHTTP) - mux.HandleFunc("/auto", newHTTPBlockRequestAction(403, "auto").ServeHTTP) + mux.HandleFunc("/json", newHTTPBlockRequestAction(403, BlockingTemplateJSON).ServeHTTP) + mux.HandleFunc("/html", newHTTPBlockRequestAction(403, BlockingTemplateHTML).ServeHTTP) + mux.HandleFunc("/auto", newHTTPBlockRequestAction(403, BlockingTemplateAuto).ServeHTTP) defer srv.Close() t.Run("json", func(t *testing.T) { diff --git a/internal/appsec/emitter/waf/actions/block.go b/internal/appsec/emitter/waf/actions/block.go index ae802b60bd..f2c86a2b49 100644 --- a/internal/appsec/emitter/waf/actions/block.go +++ b/internal/appsec/emitter/waf/actions/block.go @@ -47,15 +47,23 @@ func init() { registerActionHandler("block_request", NewBlockAction) } +const ( + BlockingTemplateJSON blockingTemplateType = "json" + BlockingTemplateHTML blockingTemplateType = "html" + BlockingTemplateAuto blockingTemplateType = "auto" +) + type ( + blockingTemplateType string + // blockActionParams are the dynamic parameters to be provided to a "block_request" // action type upon invocation blockActionParams struct { // GRPCStatusCode is the gRPC status code to be returned. Since 0 is the OK status, the value is nullable to // be able to distinguish between unset and defaulting to Abort (10), or set to OK (0). - GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"` - StatusCode int `mapstructure:"status_code"` - Type string `mapstructure:"type,omitempty"` + GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"` + StatusCode int `mapstructure:"status_code"` + Type blockingTemplateType `mapstructure:"type,omitempty"` } // GRPCWrapper is an opaque prototype abstraction for a gRPC handler (to avoid importing grpc) // that returns a status code and an error @@ -70,6 +78,12 @@ type ( BlockHTTP struct { http.Handler } + + HTTPBlockHandlerConfig struct { + Template []byte + ContentType string + StatusCode int + } ) func (a *BlockGRPC) EmitData(op dyngo.Operation) { @@ -83,32 +97,28 @@ func (a *BlockHTTP) EmitData(op dyngo.Operation) { } func newGRPCBlockRequestAction(status int) *BlockGRPC { - return &BlockGRPC{GRPCWrapper: newGRPCBlockHandler(status)} -} - -func newGRPCBlockHandler(status int) GRPCWrapper { - return func() (uint32, error) { + return &BlockGRPC{GRPCWrapper: func() (uint32, error) { return uint32(status), &events.BlockingSecurityEvent{} - } + }} } func blockParamsFromMap(params map[string]any) (blockActionParams, error) { grpcCode := 10 - p := blockActionParams{ - Type: "auto", + parsedParams := blockActionParams{ + Type: BlockingTemplateAuto, StatusCode: 403, GRPCStatusCode: &grpcCode, } - if err := mapstructure.WeakDecode(params, &p); err != nil { - return p, err + if err := mapstructure.WeakDecode(params, &parsedParams); err != nil { + return parsedParams, err } - if p.GRPCStatusCode == nil { - p.GRPCStatusCode = &grpcCode + if parsedParams.GRPCStatusCode == nil { + parsedParams.GRPCStatusCode = &grpcCode } - return p, nil + return parsedParams, nil } // NewBlockAction creates an action for the "block_request" action type @@ -124,38 +134,83 @@ func NewBlockAction(params map[string]any) []Action { } } -func newHTTPBlockRequestAction(status int, template string) *BlockHTTP { - return &BlockHTTP{Handler: newBlockHandler(status, template)} +func newHTTPBlockRequestAction(statusCode int, template blockingTemplateType) *BlockHTTP { + return &BlockHTTP{Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + template := template + if template == BlockingTemplateAuto { + template = blockingTemplateTypeFromHeaders(request.Header) + } + + if code, found := UnwrapGetStatusCode(writer); found && code != 0 { + // The status code has already been set, so we can't change it, do nothing + return + } + + if blocker, found := UnwrapBlocker(writer); found { + // We found our custom response writer, so we can block futur calls to Write and WriteHeader + defer blocker() + } + + writer.Header().Set("Content-Type", template.ContentType()) + writer.WriteHeader(statusCode) + writer.Write(template.Template()) + })} } -// newBlockHandler creates, initializes and returns a new BlockRequestAction -func newBlockHandler(status int, template string) http.Handler { - htmlHandler := newBlockRequestHandler(status, "text/html", blockedTemplateHTML) - jsonHandler := newBlockRequestHandler(status, "application/json", blockedTemplateJSON) - switch template { - case "json": - return jsonHandler - case "html": - return htmlHandler - default: - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := jsonHandler - hdr := r.Header.Get("Accept") - htmlIdx := strings.Index(hdr, "text/html") - jsonIdx := strings.Index(hdr, "application/json") - // Switch to html handler if text/html comes before application/json in the Accept header - if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) { - h = htmlHandler - } - h.ServeHTTP(w, r) - }) +func blockingTemplateTypeFromHeaders(headers http.Header) blockingTemplateType { + hdr := headers.Get("Accept") + htmlIdx := strings.Index(hdr, "text/html") + jsonIdx := strings.Index(hdr, "application/json") + // Switch to html handler if text/html comes before application/json in the Accept header + if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) { + return BlockingTemplateHTML } + + return BlockingTemplateJSON +} + +func (typ blockingTemplateType) Template() []byte { + if typ == BlockingTemplateHTML { + return blockedTemplateHTML + } + + return blockedTemplateJSON +} + +func (typ blockingTemplateType) ContentType() string { + if typ == BlockingTemplateHTML { + return "text/html" + } + + return "application/json" +} + +// UnwrapBlocker unwraps the right struct method from contrib/internal/httptrace.responseWriter +// and returns the Block() function and if it was found. +func UnwrapBlocker(writer http.ResponseWriter) (func(), bool) { + // this is part of the contrib/internal/httptrace.responseWriter interface + wrapped, ok := writer.(interface { + Block() + }) + if !ok { + // Somehow we can't access the wrapped response writer, so we can't block the response + return nil, false + } + + return wrapped.Block, true } -func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", ct) - w.WriteHeader(status) - w.Write(payload) +// UnwrapGetStatusCode unwraps the right struct method from contrib/internal/httptrace.responseWriter +// and calls it to know if a call to WriteHeader has been made and returns the status code. +func UnwrapGetStatusCode(writer http.ResponseWriter) (int, bool) { + // this is part of the contrib/internal/httptrace.responseWriter interface + wrapped, ok := writer.(interface { + Status() int }) + if !ok { + // Somehow we can't access the wrapped response writer, so we can't get the status code + return 0, false + } + + return wrapped.Status(), true } diff --git a/internal/appsec/emitter/waf/actions/http_redirect.go b/internal/appsec/emitter/waf/actions/http_redirect.go index 3cdca4c818..8bbb8db0a7 100644 --- a/internal/appsec/emitter/waf/actions/http_redirect.go +++ b/internal/appsec/emitter/waf/actions/http_redirect.go @@ -25,9 +25,9 @@ func init() { } func redirectParamsFromMap(params map[string]any) (redirectActionParams, error) { - var p redirectActionParams - err := mapstructure.WeakDecode(params, &p) - return p, err + var parsedParams redirectActionParams + err := mapstructure.WeakDecode(params, &parsedParams) + return parsedParams, err } func newRedirectRequestAction(status int, loc string) *BlockHTTP { @@ -38,9 +38,24 @@ func newRedirectRequestAction(status int, loc string) *BlockHTTP { // If location is not set we fall back on a default block action if loc == "" { - return &BlockHTTP{Handler: newBlockHandler(http.StatusForbidden, string(blockedTemplateJSON))} + return newHTTPBlockRequestAction(http.StatusForbidden, BlockingTemplateAuto) } - return &BlockHTTP{Handler: http.RedirectHandler(loc, status)} + + redirectHandler := http.RedirectHandler(loc, status) + return &BlockHTTP{Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if UnwrapGetStatusCode(writer) != 0 { + // The status code has already been set, so we can't change it, do nothing + return + } + + blocker, found := UnwrapBlocker(writer) + if found { + // We found our custom response writer, so we can block futur calls to Write and WriteHeader + defer blocker() + } + + redirectHandler.ServeHTTP(writer, request) + })} } // NewRedirectAction creates an action for the "redirect_request" action type diff --git a/internal/appsec/waf_test.go b/internal/appsec/waf_test.go index 7e169ffb7a..8b32f2dc83 100644 --- a/internal/appsec/waf_test.go +++ b/internal/appsec/waf_test.go @@ -335,15 +335,22 @@ func TestBlocking(t *testing.T) { w.Write([]byte("Hello World!\n")) }) mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { - if err := pAppsec.SetUser(r.Context(), r.Header.Get("test-usr")); err != nil { + if r.Header.Get("write-before-block") != "" { + w.WriteHeader(204) + } + + if err := pAppsec.SetUser(r.Context(), r.Header.Get("test-usr")); err != nil && r.Header.Get("write-after-block") == "" { return } w.Write([]byte("Hello World!\n")) }) mux.HandleFunc("/body", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("write-before-block") != "" { + w.WriteHeader(204) + } buf := new(strings.Builder) io.Copy(buf, r.Body) - if err := pAppsec.MonitorParsedHTTPBody(r.Context(), buf.String()); err != nil { + if err := pAppsec.MonitorParsedHTTPBody(r.Context(), buf.String()); err != nil && r.Header.Get("write-after-block") == "" { return } w.Write([]byte("Hello World!\n")) @@ -395,6 +402,20 @@ func TestBlocking(t *testing.T) { status: 403, ruleMatch: userBlockingRule, }, + { + name: "user/no-write-after-block", + headers: map[string]string{"test-usr": "blocked-user-1", "write-after-block": "true"}, + endpoint: "/user", + status: 403, + ruleMatch: userBlockingRule, + }, + { + name: "user/cannot-block-because-write-before-block", + headers: map[string]string{"test-usr": "blocked-user-1", "write-before-block": "true"}, + endpoint: "/user", + status: 204, + ruleMatch: userBlockingRule, + }, // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler // to be invoked while IP blocking doesn't { @@ -417,6 +438,22 @@ func TestBlocking(t *testing.T) { reqBody: "$globals", ruleMatch: bodyBlockingRule, }, + { + name: "body/no-write-after-block", + headers: map[string]string{"write-after-block": "true"}, + endpoint: "/body", + status: 403, + reqBody: "$globals", + ruleMatch: bodyBlockingRule, + }, + { + name: "body/cannot-block-because-write-before-block", + headers: map[string]string{"write-before-block": "true"}, + endpoint: "/body", + status: 204, + reqBody: "$globals", + ruleMatch: bodyBlockingRule, + }, } { t.Run(tc.name, func(t *testing.T) { mt := mocktracer.Start() @@ -430,12 +467,17 @@ func TestBlocking(t *testing.T) { require.NoError(t, err) defer res.Body.Close() require.Equal(t, tc.status, res.StatusCode) - b, err := io.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) require.NoError(t, err) - if tc.status == 200 { - require.Equal(t, "Hello World!\n", string(b)) - } else { - require.NotEqual(t, "Hello World!\n", string(b)) + switch tc.status { + case 200: + require.Equal(t, "Hello World!\n", string(body)) + case 204: + require.Empty(t, string(body)) + case 403: + require.Contains(t, string(body), "Security provided by Datadog") + default: + panic("unexpected status code") } if tc.ruleMatch != "" { spans := mt.FinishedSpans()