forked from TykTechnologies/tyk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
handler_websocket.go
128 lines (109 loc) · 2.67 KB
/
handler_websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package main
import (
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/Sirupsen/logrus"
)
// canonicalAddr returns the host:port address to dial for url.
//
// If url.Host already carries an explicit port it is returned unchanged.
// Otherwise the scheme's default port is appended: 443 for "https" and
// "wss", 80 for everything else (including "http" and "ws").
func canonicalAddr(url *url.URL) string {
	addr := url.Host

	// A port is present when the last ':' comes after the closing ']' of an
	// IPv6 literal. LastIndex yields -1 when there is no ']', so plain
	// "host:port" also passes, while a bare "[::1]" does not.
	if strings.LastIndex(addr, ":") > strings.LastIndex(addr, "]") {
		return addr
	}

	switch strings.ToLower(url.Scheme) {
	case "https", "wss":
		return addr + ":443"
	default:
		return addr + ":80"
	}
}
// WSDialer proxies a WebSocket upgrade request for a single client. Its
// RoundTrip method has the http.RoundTripper signature. It dials the backend
// and splices the hijacked client connection to it.
type WSDialer struct {
	// Embedded transport. RoundTrip reads the promoted Dial and
	// TLSClientConfig fields from it.
	*TykTransporter
	// RW is the client's response writer. RoundTrip writes error responses
	// to it and hijacks it to obtain the raw client connection.
	RW http.ResponseWriter
	// TLSConfig is not read by RoundTrip in this file, which uses the
	// embedded TLSClientConfig instead. NOTE(review): confirm whether any
	// caller relies on it or whether it is dead.
	TLSConfig *tls.Config
}
// RoundTrip proxies a WebSocket upgrade request by splicing the client and
// backend connections together at the byte level.
//
// It does the following:
//   - dials the backend (over TLS when the request scheme is "wss" or "https");
//   - hijacks ws.RW to get the raw client connection;
//   - writes req to the backend;
//   - copies bytes in both directions until either direction stops.
//
// Both connections are closed before it returns.
//
// It returns an error when WebSockets are disabled, the backend cannot be
// dialed, ws.RW cannot be hijacked, req cannot be forwarded, or a copy fails.
// The returned *http.Response is always nil, even when err is nil.
// NOTE(review): callers must tolerate that, since it departs from the usual
// http.RoundTripper contract.
func (ws *WSDialer) RoundTrip(req *http.Request) (*http.Response, error) {
	if !globalConf.HttpServerOptions.EnableWebSockets {
		return nil, errors.New("WebSockets has been disabled on this host")
	}

	target := canonicalAddr(req.URL)

	// Prefer the dial function configured on the embedded transport and fall
	// back to a plain dial.
	dial := ws.Dial
	if dial == nil {
		dial = net.Dial
	}

	// Secure backends are detected from the URL scheme. The incoming "wss"
	// scheme is not always preserved here, so "https" is accepted as well.
	switch req.URL.Scheme {
	case "wss", "https":
		tlsConfig := ws.TLSClientConfig
		if tlsConfig == nil {
			tlsConfig = &tls.Config{}
		}
		dial = func(network, address string) (net.Conn, error) {
			return tls.Dial("tcp", target, tlsConfig)
		}
	}

	d, err := dial("tcp", target)
	if err != nil {
		http.Error(ws.RW, "Error contacting backend server.", 500)
		log.WithFields(logrus.Fields{
			"path":   target,
			"origin": GetIPFromRequest(req),
		}).Errorf("Error dialing websocket backend %s: %v", target, err)
		return nil, err
	}
	defer d.Close()

	hj, ok := ws.RW.(http.Hijacker)
	if !ok {
		http.Error(ws.RW, "Not a hijacker?", 500)
		return nil, errors.New("Not a hijacker?")
	}

	nc, brw, err := hj.Hijack()
	if err != nil {
		log.WithFields(logrus.Fields{
			"path":   req.URL.Path,
			"origin": GetIPFromRequest(req),
		}).Errorf("Hijack error: %v", err)
		return nil, err
	}
	defer nc.Close()

	if err := req.Write(d); err != nil {
		log.WithFields(logrus.Fields{
			"path":   req.URL.Path,
			"origin": GetIPFromRequest(req),
		}).Errorf("Error copying request to target: %v", err)
		return nil, err
	}

	// Each direction reports its io.Copy result on errc. The buffer of 2 lets
	// the goroutine that finishes last exit even though nobody reads its value.
	errc := make(chan error, 2)
	cp := func(dst io.Writer, src io.Reader) {
		_, err := io.Copy(dst, src)
		errc <- err
	}
	// Read client data through the hijacked bufio.Reader so that any bytes it
	// already buffered are forwarded instead of silently dropped.
	go cp(d, brw.Reader)
	go cp(nc, d)

	// Stop as soon as either direction finishes. Returning runs the deferred
	// Close calls on both connections, which unblocks the other copier.
	// Waiting for both would keep the connections open until the idle peer
	// happened to send more data.
	if err = <-errc; err != nil {
		log.WithFields(logrus.Fields{
			"path":   req.URL.Path,
			"origin": GetIPFromRequest(req),
		}).Errorf("Error transmitting request: %v", err)
	}
	return nil, err
}
// IsWebsocket reports whether req is a WebSocket upgrade handshake. It always
// returns false when WebSockets are disabled in the global configuration.
//
// The Connection and Upgrade headers are treated as comma-separated token
// lists and matched case-insensitively (RFC 7230 §6.1, RFC 6455 §4.2.1).
// This means values such as "keep-alive, Upgrade", which some browsers send,
// are recognised.
func IsWebsocket(req *http.Request) bool {
	if !globalConf.HttpServerOptions.EnableWebSockets {
		return false
	}
	return hasHeaderToken(req.Header, "Connection", "upgrade") &&
		hasHeaderToken(req.Header, "Upgrade", "websocket")
}

// hasHeaderToken reports whether any comma-separated element of any value of
// header name in h equals token, ignoring ASCII case and surrounding spaces.
func hasHeaderToken(h http.Header, name, token string) bool {
	for _, v := range h[http.CanonicalHeaderKey(name)] {
		for _, elem := range strings.Split(v, ",") {
			if strings.EqualFold(strings.TrimSpace(elem), token) {
				return true
			}
		}
	}
	return false
}