forked from socialpoint-labs/bsk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecorator.go
188 lines (162 loc) · 5.48 KB
/
decorator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package httpx
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
)
// Decorator wraps (decorates) an http.Handler with additional functionality,
// returning the handler that should be served in its place.
type Decorator func(http.Handler) http.Handler
// AddHeaderDecorator returns a Decorator that appends value to the response
// header key, keeping any values already present, and then calls the wrapped
// handler.
func AddHeaderDecorator(key, value string) Decorator {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, req *http.Request) {
			rw.Header().Add(key, value)
			next.ServeHTTP(rw, req)
		}
		return http.HandlerFunc(serve)
	}
}
// SetHeaderDecorator returns a Decorator that sets the response header key to
// value, replacing any values already present, and then calls the wrapped
// handler.
func SetHeaderDecorator(key, value string) Decorator {
	return func(next http.Handler) http.Handler {
		var wrapped http.HandlerFunc = func(rw http.ResponseWriter, req *http.Request) {
			hdr := rw.Header()
			hdr.Set(key, value)
			next.ServeHTTP(rw, req)
		}
		return wrapped
	}
}
// CheckHeaderDecorator returns a Decorator that only lets a request through to
// the wrapped handler when its headerName request header equals headerValue.
// A missing header reads as "", so it passes only when headerValue is "".
// Any other request is answered with statusCode, and the body is that code's
// standard status text (http.StatusText).
func CheckHeaderDecorator(headerName, headerValue string, statusCode int) Decorator {
	rejection := http.StatusText(statusCode)
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			if req.Header.Get(headerName) == headerValue {
				next.ServeHTTP(rw, req)
				return
			}
			rw.WriteHeader(statusCode)
			// Best effort: a failed write here cannot be reported to the client anyway.
			_, _ = rw.Write([]byte(rejection))
		})
	}
}
// RootDecorator returns a Decorator that lets only requests for exactly "/"
// reach the wrapped handler and answers every other path with 404 Not Found.
// A ServeMux routes both the root path and every unmatched URL to the "/"
// pattern; this decorator tells the two apart.
// See https://golang.org/pkg/net/http/#example_ServeMux_Handle
func RootDecorator() Decorator {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			switch req.URL.Path {
			case "/":
				next.ServeHTTP(rw, req)
			default:
				http.NotFound(rw, req)
			}
		})
	}
}
// StripPrefixDecorator returns a Decorator that removes prefix from the request
// URL path before calling the wrapped handler. Requests whose path does not
// start with prefix are answered with 404 Not Found, so an empty prefix
// rejects every request.
//
// Like http.StripPrefix, the wrapped handler receives a shallow copy of the
// request with its own URL. The caller's *http.Request is never modified, as
// net/http requires of handlers, so outer decorators (e.g. request logging)
// still see the original path. URL.RawPath is trimmed too, or cleared when it
// does not carry the prefix, so it never contradicts URL.Path.
func StripPrefixDecorator(prefix string) Decorator {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			p := strings.TrimPrefix(r.URL.Path, prefix)
			if len(p) == len(r.URL.Path) {
				// Nothing was trimmed: the path does not start with prefix.
				http.NotFound(w, r)
				return
			}

			u := *r.URL
			u.Path = p
			// An empty RawPath makes URL.EscapedPath re-derive the encoding from Path.
			u.RawPath = ""
			if rp := strings.TrimPrefix(r.URL.RawPath, prefix); len(rp) < len(r.URL.RawPath) {
				u.RawPath = rp
			}

			r2 := *r
			r2.URL = &u
			h.ServeHTTP(w, &r2)
		})
	}
}
// EnableCORSDecorator returns a Decorator that sets permissive CORS response
// headers (any origin, a fixed list of methods and request headers) on every
// response. OPTIONS requests are treated as preflight requests and answered
// directly with 200 OK. All other methods are passed on to the wrapped handler.
func EnableCORSDecorator() Decorator {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			hdr := rw.Header()
			hdr.Set("Access-Control-Allow-Origin", "*")
			hdr.Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,HEAD,OPTIONS")
			hdr.Set("Access-Control-Allow-Headers", "Origin,Accept,Content-Type,Authorization")

			if req.Method != http.MethodOptions {
				next.ServeHTTP(rw, req)
				return
			}
			// It's a preflight request: the CORS headers above are the whole answer.
			rw.WriteHeader(http.StatusOK)
		})
	}
}
// Condition reports whether a predicate holds for the given http.Request and
// the current state of the http.ResponseWriter. IfDecorator uses it to choose
// which handler serves a request.
type Condition func(w http.ResponseWriter, r *http.Request) bool
// IfDecorator returns a Decorator that evaluates cond once per request. When
// cond returns true, the request is served by then. Otherwise it goes to the
// wrapped handler.
func IfDecorator(cond Condition, then http.Handler) Decorator {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			target := next
			if cond(rw, req) {
				target = then
			}
			target.ServeHTTP(rw, req)
		})
	}
}
// TimeoutDecorator returns a Decorator that serves the wrapped handler with a
// copy of the request whose context expires after timeout. The context is
// cancelled when the wrapped handler returns. Nothing is interrupted
// forcibly: child handlers are responsible for obeying the context deadline.
func TimeoutDecorator(timeout time.Duration) Decorator {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			deadlineCtx, release := context.WithTimeout(req.Context(), timeout)
			defer release()
			next.ServeHTTP(rw, req.WithContext(deadlineCtx))
		})
	}
}
// LoggingDecorator returns a decorator that writes one access-log line per
// request to logWriter after the wrapped handler returns. The line is built
// by formatLogLine from the request, the time the request was received, the
// response status code and the number of body bytes written.
//
// Write errors on logWriter are ignored: logging is best-effort and must not
// affect the response.
func LoggingDecorator(logWriter io.Writer) Decorator {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			// Common Log Format's timestamp is the time the request was received,
			// so capture it before the (possibly slow) handler runs.
			start := time.Now()
			resLogger := &responseLogger{ResponseWriter: w}
			h.ServeHTTP(resLogger, req)

			status := resLogger.Status()
			if status == 0 {
				// The handler wrote neither a header nor a body; net/http then
				// replies 200 OK on its behalf.
				status = http.StatusOK
			}
			_, _ = fmt.Fprintln(logWriter, formatLogLine(req, start, status, resLogger.Size()))
		})
	}
}
// formatLogLine renders a single access-log line of the form
//
//	host [02/Jan/2006:15:04:05 -0700] METHOD request-uri PROTO status size
//
// The host is taken from req.RemoteAddr with the port removed. The full
// RemoteAddr is used as-is when it cannot be split into host and port.
func formatLogLine(req *http.Request, ts time.Time, status int, size int) string {
	remote := req.RemoteAddr
	if h, _, splitErr := net.SplitHostPort(remote); splitErr == nil {
		remote = h
	}
	return fmt.Sprintf("%s [%s] %s %s %s %d %d",
		remote,
		ts.Format("02/Jan/2006:15:04:05 -0700"),
		req.Method,
		req.URL.RequestURI(),
		req.Proto,
		status,
		size,
	)
}
// responseLogger wraps an http.ResponseWriter and records the final HTTP
// status code and the number of body bytes written through it, so they can
// be logged after the handler returns.
type responseLogger struct {
	http.ResponseWriter
	status int // final status code; 0 until WriteHeader or Write is called
	size   int // total body bytes accepted by the underlying writer
}

// Write forwards b to the underlying writer and adds the bytes written to the
// running size.
func (l *responseLogger) Write(b []byte) (int, error) {
	if l.status == 0 {
		// The status will be StatusOK if WriteHeader has not been called yet.
		l.status = http.StatusOK
	}
	n, err := l.ResponseWriter.Write(b)
	l.size += n
	return n, err
}

// WriteHeader forwards s to the underlying writer. Only the first final
// status is recorded. net/http ignores superfluous WriteHeader calls, and
// 1xx informational headers (except 101 Switching Protocols) may precede the
// real status, so recording later or informational codes would log a status
// the client never received as the response status.
func (l *responseLogger) WriteHeader(s int) {
	l.ResponseWriter.WriteHeader(s)
	informational := s >= 100 && s < 200 && s != http.StatusSwitchingProtocols
	if l.status == 0 && !informational {
		l.status = s
	}
}

// Status returns the recorded status code, or 0 if neither WriteHeader (with
// a final status) nor Write has been called.
func (l *responseLogger) Status() int {
	return l.status
}

// Size returns the number of body bytes written so far.
func (l *responseLogger) Size() int {
	return l.size
}

// Unwrap returns the underlying ResponseWriter. http.ResponseController uses
// it to reach optional interfaces (Flusher, Hijacker, ...) that this wrapper
// does not expose itself.
func (l *responseLogger) Unwrap() http.ResponseWriter {
	return l.ResponseWriter
}