-
Notifications
You must be signed in to change notification settings - Fork 1
/
proxy.go
105 lines (96 loc) · 2.57 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package main
import (
"log"
"net/http"
"strings"
"sync"
)
// Proxy holds the registered proxy handlers and serves HTTP requests by
// dispatching each one to the matching handler.
//
// The zero value is ready to use. A Proxy must not be copied after first use,
// because it contains a sync.Map.
type Proxy struct {
	// hostMap maps a lookup key to its *ProxyHandler. A key is a request
	// hostname, "/" plus a first path segment, or the wildcard "any".
	hostMap sync.Map
}
// Host describes how a single proxied host is configured.
type Host struct {
	// Target identifies where this host's traffic is sent. Presumably this is
	// an upstream URL or host:port — TODO confirm against the code that builds
	// handlers from a Host.
	Target string

	// SetCookiePath presumably turns on rewriting of upstream Set-Cookie
	// headers to Path=/ (it mirrors proxyTransport.SetCookiePath) — verify
	// against the caller.
	SetCookiePath bool
}
// Handle registers handler under the key host. Any handler previously stored
// for that key is replaced; there is no check for an existing entry.
//
// ServeHTTP looks keys up in three forms, so host may take any of them:
//   - a request Host header value;
//   - "/" followed by the first path segment;
//   - the wildcard "any".
func (proxy *Proxy) Handle(host string, handler *ProxyHandler) {
	proxy.hostMap.Store(host, handler)
}
// Exists reports whether a handler is registered under host and that
// handler's TargetName equals target.
//
// It returns false in two cases: no handler is registered for host, or the
// registered handler points at a different target.
func (proxy *Proxy) Exists(host, target string) bool {
	item, ok := proxy.hostMap.Load(host)
	return ok && item.(*ProxyHandler).TargetName == target
}
// ServeHTTP routes r to a registered *ProxyHandler. Keys are tried in this
// order, and the first registered key wins:
//
//  1. r.Host exactly as received (it includes a port if the client sent one);
//  2. "/" plus the first segment of the request path, e.g. "/app" for
//     "/app/index.html?x=1";
//  3. the wildcard key "any".
//
// If none of these keys is registered, it replies 404 Not Found.
func (proxy *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// serve dispatches r to the handler stored under key and reports whether
	// such a handler existed.
	serve := func(key string) bool {
		result, ok := proxy.hostMap.Load(key)
		if !ok {
			return false
		}
		result.(*ProxyHandler).Handler.ServeHTTP(w, r)
		return true
	}

	// Match to hostname.
	if serve(r.Host) {
		return
	}

	// Match against the first path segment. The escaped path keeps the same
	// raw percent-encoding as r.RequestURI. Unlike RequestURI, it carries no
	// query string: "/app?x=1" previously looked up "/app?x=1". It also
	// carries no scheme/host for absolute-form requests.
	if r.URL != nil {
		segments := strings.Split(r.URL.EscapedPath(), "/")
		if len(segments) > 1 && serve("/"+segments[1]) {
			return
		}
	}

	// Neither hostname nor path matched, so try the wildcard.
	if serve("any") {
		return
	}
	http.Error(w, "Not found", http.StatusNotFound)
}
type proxyTransport struct {
SetCookiePath bool
CapturedTransport http.RoundTripper
}
// RoundTrip sends r through the package-level transport and post-processes the
// upstream response before the proxy relays it:
//
//   - on a transport error, idle connections are closed and the error is
//     logged and returned unchanged;
//   - on a 5xx status, idle connections are closed so that later requests do
//     not reuse pooled connections;
//   - when t.SetCookiePath is true, every Set-Cookie header is rewritten to
//     Path=/ while its other attributes (Secure, HttpOnly, Expires, SameSite,
//     ...) are kept;
//   - when args.HSTS is true, a Strict-Transport-Security header is set.
//
// NOTE(review): t.CapturedTransport is never read here; the package-level
// transport is used instead. Confirm which is intended before switching.
func (t *proxyTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	// closeIdle drops pooled idle connections when the transport supports it.
	// The checked assertion avoids the panic that the former bare
	// *http.Transport assertion raised for any other RoundTripper.
	closeIdle := func() {
		if c, ok := transport.(interface{ CloseIdleConnections() }); ok {
			c.CloseIdleConnections()
		}
	}

	// Use the real transport to execute the request
	response, err := transport.RoundTrip(r)
	if err != nil {
		closeIdle()
		log.Print("Unable to get response from target server: " + err.Error())
		return nil, err
	}
	if response.StatusCode >= 500 {
		closeIdle()
	}
	if t.SetCookiePath {
		rewriteCookiePathsToRoot(response.Header)
	}
	if args.HSTS {
		response.Header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
	}
	return response, nil
}

// rewriteCookiePathsToRoot replaces every Set-Cookie header in h with a copy
// scoped to Path=/. It reads the canonical key directly instead of ranging
// over h, so the map is never modified while it is being iterated.
func rewriteCookiePathsToRoot(h http.Header) {
	cookies := h["Set-Cookie"]
	// Remove the current Set-Cookie headers before adding the rewritten ones.
	h.Del("Set-Cookie")
	for _, cookie := range cookies {
		h.Add("Set-Cookie", cookieWithRootPath(cookie))
	}
}

// cookieWithRootPath returns the Set-Cookie value cookie with any Path
// attribute removed and "Path=/" placed directly after the name=value pair.
// All other non-empty attributes are kept in their original order, for
// example "a=b; HttpOnly; Path=/x; Secure" becomes
// "a=b; Path=/; HttpOnly; Secure".
func cookieWithRootPath(cookie string) string {
	parts := strings.Split(cookie, ";")
	attrs := []string{parts[0], "Path=/"}
	for _, attr := range parts[1:] {
		attr = strings.TrimSpace(attr)
		if attr == "" {
			continue
		}
		name := attr
		if i := strings.IndexByte(attr, '='); i >= 0 {
			name = attr[:i]
		}
		// Attribute names are case-insensitive (RFC 6265 §5.2).
		if strings.EqualFold(strings.TrimSpace(name), "Path") {
			continue
		}
		attrs = append(attrs, attr)
	}
	return strings.Join(attrs, "; ")
}