diff --git a/log/access/access_log.go b/log/access/access_log.go index 6626357..e306cb9 100644 --- a/log/access/access_log.go +++ b/log/access/access_log.go @@ -11,41 +11,9 @@ import ( "time" "resenje.org/logging" + "resenje.org/web" ) -type responseLogger struct { - w http.ResponseWriter - status int - size int -} - -func (l *responseLogger) Header() http.Header { - return l.w.Header() -} - -func (l *responseLogger) Flush() { - l.w.(http.Flusher).Flush() -} - -func (l *responseLogger) Push(target string, opts *http.PushOptions) error { - return l.w.(http.Pusher).Push(target, opts) -} - -func (l *responseLogger) Write(b []byte) (int, error) { - if l.status == 0 { - // The status will be StatusOK if WriteHeader has not been called yet - l.status = http.StatusOK - } - size, err := l.w.Write(b) - l.size += size - return size, err -} - -func (l *responseLogger) WriteHeader(s int) { - l.w.WriteHeader(s) - l.status = s -} - // NewHandler returns a handler that logs HTTP requests. // It logs information about remote address, X-Forwarded-For or X-Real-Ip, // HTTP method, request URI, HTTP protocol, HTTP response status, total bytes @@ -54,7 +22,7 @@ func (l *responseLogger) WriteHeader(s int) { func NewHandler(h http.Handler, logger *logging.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { startTime := time.Now() - rl := &responseLogger{w, 0, 0} + rl := web.NewResponseStatusRecorder(w) h.ServeHTTP(rl, r) referrer := r.Referer() if referrer == "" { @@ -77,19 +45,20 @@ func NewHandler(h http.Handler, logger *logging.Logger) http.Handler { if len(ips) > 0 { xips = strings.Join(ips, ", ") } + status := rl.Status() var level logging.Level switch { - case rl.status >= 500: + case status >= 500: level = logging.ERROR - case rl.status >= 400: + case status >= 400: level = logging.WARNING - case rl.status >= 300: + case status >= 300: level = logging.INFO - case rl.status >= 200: + case status >= 200: level = logging.INFO default: level = logging.DEBUG } - logger.Logf(level, "%s \"%s\" \"%v %s %v\" %d %d %f \"%s\" \"%s\"", r.RemoteAddr, xips, r.Method, r.RequestURI, r.Proto, rl.status, rl.size, time.Since(startTime).Seconds(), referrer, userAgent) + logger.Logf(level, "%s \"%s\" \"%v %s %v\" %d %d %f \"%s\" \"%s\"", r.RemoteAddr, xips, r.Method, r.RequestURI, r.Proto, status, rl.ResponseBodySize(), time.Since(startTime).Seconds(), referrer, userAgent) }) } diff --git a/log/access/access_log_test.go b/log/access/access_log_test.go index c04dca3..c9f1d24 100644 --- a/log/access/access_log_test.go +++ b/log/access/access_log_test.go @@ -29,9 +29,10 @@ func TestAccessLog(t *testing.T) { pattern *regexp.Regexp }{ { - name: "GET", - request: httptest.NewRequest("", "/", nil), - pattern: regexp.MustCompile(`^INFO 192.0.2.1:1234 "-" "GET / HTTP/1.1" 200 9 0.\d{6} "-" "-"$`), + name: "GET", + request: httptest.NewRequest("", "/", nil), + statusCode: http.StatusOK, + pattern: regexp.MustCompile(`^INFO 192.0.2.1:1234 "-" "GET / HTTP/1.1" 200 9 0.\d{6} "-" "-"$`), }, { name: "POST", @@ -112,5 +113,4 @@ func TestAccessLog(t *testing.T) { } }) } - } diff --git a/response_recorder.go b/response_recorder.go new file mode 100644 index 0000000..b8b1e21 --- /dev/null +++ b/response_recorder.go @@ -0,0 +1,58 @@ +// Copyright (c) 2021, Janoš Guljaš +// All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package web + +import "net/http" + +// ResponseStatusRecorder implements http.ResponseWriter that keeps tack of HTTP +// response status code and written body size in bytes. +type ResponseStatusRecorder struct { + http.ResponseWriter + status int + size int +} + +// NewResponseStatusRecorder wraps an http.ResponseWriter with +// ResponseStatusRecorder in order to record the status code and written body +// size. +func NewResponseStatusRecorder(w http.ResponseWriter) *ResponseStatusRecorder { + return &ResponseStatusRecorder{ + ResponseWriter: w, + } +} + +// Write implements http.ResponseWriter. +func (r *ResponseStatusRecorder) Write(b []byte) (int, error) { + size, err := r.ResponseWriter.Write(b) + if err != nil { + return 0, err + } + if r.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + r.status = http.StatusOK + } + r.size += size + return size, err +} + +// WriteHeader implements http.ResponseWriter. +func (r *ResponseStatusRecorder) WriteHeader(s int) { + r.ResponseWriter.WriteHeader(s) + if r.status == 0 { + r.status = s + } +} + +// Status returns the responded status code. If it is 0, no response data has +// been written. +func (r *ResponseStatusRecorder) Status() int { + return r.status +} + +// ResponseBodySize returns the number of bytes that are written as the response body. +func (r *ResponseStatusRecorder) ResponseBodySize() int { + return r.size +} diff --git a/response_recorder_test.go b/response_recorder_test.go new file mode 100644 index 0000000..1933310 --- /dev/null +++ b/response_recorder_test.go @@ -0,0 +1,99 @@ +// Copyright (c) 2021, Janoš Guljaš +// All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package web_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "resenje.org/web" +) + +func TestResponseStatusRecorder_noWrite(t *testing.T) { + w := httptest.NewRecorder() + + rec := web.NewResponseStatusRecorder(w) + + if size := rec.ResponseBodySize(); size != 0 { + t.Errorf("got %v bytes that are written as body, want 0", size) + } + if status := rec.Status(); status != 0 { + t.Errorf("git status %v, want %v", status, 0) + } +} + +func TestResponseStatusRecorder_write(t *testing.T) { + w := httptest.NewRecorder() + + rec := web.NewResponseStatusRecorder(w) + + n, err := rec.Write([]byte("hi")) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Errorf("got %v bytes that are written, want 2", n) + } + if size := rec.ResponseBodySize(); size != 2 { + t.Errorf("got %v bytes that are written as body, want 2", size) + } + if status := rec.Status(); status != http.StatusOK { + t.Errorf("git status %v, want %v", status, http.StatusOK) + } + + n, err = rec.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if n != 5 { + t.Errorf("got %v bytes that are written, want 5", n) + } + if size := rec.ResponseBodySize(); size != 7 { + t.Errorf("got %v bytes that are written as body, want 7", size) + } +} + +func TestResponseStatusRecorder_writeHeader(t *testing.T) { + w := httptest.NewRecorder() + + rec := web.NewResponseStatusRecorder(w) + + rec.WriteHeader(http.StatusTeapot) + + if size := rec.ResponseBodySize(); size != 0 { + t.Errorf("got %v bytes that are written as body, want 0", size) + } + if status := rec.Status(); status != http.StatusTeapot { + t.Errorf("git status %v, want %v", status, http.StatusTeapot) + } +} + +func TestResponseStatusRecorder_writeHeaderAfterWrite(t *testing.T) { + w := httptest.NewRecorder() + + rec := web.NewResponseStatusRecorder(w) + + n, err := rec.Write([]byte("hi")) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Errorf("got %v bytes that are written, want 2", n) + } + if size := rec.ResponseBodySize(); size != 2 { + t.Errorf("got %v bytes that are written as body, want 2", size) + } + if status := rec.Status(); status != http.StatusOK { + t.Errorf("git status %v, want %v", status, http.StatusOK) + } + + rec.WriteHeader(http.StatusTeapot) + + if status := rec.Status(); status != http.StatusOK { + t.Errorf("git status %v, want %v", status, http.StatusOK) + } +}