-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathserver.go
278 lines (239 loc) · 7.91 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
package mwgp
import (
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"time"

	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.zx2c4.com/wireguard/device"
)
// ServerConfigPeer is one peer entry of a ServerConfigServer: where matching
// traffic is forwarded and how strictly source addresses are validated.
//
// An entry without ClientPublicKey is the server's fallback peer (see
// isFallback); ServerConfigServer.Initialize allows at most one per server.
type ServerConfigPeer struct {
	// ForwardTo is the "host:port" destination for this peer. An empty host
	// is replaced by the owning server's Address in ServerConfigServer.Initialize.
	ForwardTo string `json:"forward_to"`
	// forwardToAddress is ForwardTo resolved by ServerConfigServer.Initialize.
	forwardToAddress *net.UDPAddr
	// ClientSourceValidateLevel is the same setting as the one in
	// ServerConfigServer, used as a per-peer override. When left at
	// SourceValidateLevelDefault, Initialize copies the server's value.
	ClientSourceValidateLevel int `json:"csvl,omitempty"`
	// ServerSourceValidateLevel is the same setting as the one in
	// ServerConfigServer, used as a per-peer override. When left at
	// SourceValidateLevelDefault, Initialize copies the server's value.
	ServerSourceValidateLevel int `json:"ssvl,omitempty"`
	// ClientPublicKey selects the client this entry applies to. A nil key
	// marks the entry as the server's fallback peer.
	ClientPublicKey *NoisePublicKey `json:"pubkey,omitempty"`
	// serverPublicKey is the owning server's public key, filled in by
	// ServerConfigServer.Initialize; required by the cookie generator.
	serverPublicKey NoisePublicKey
}
// isFallback reports whether p is the catch-all entry of its server, i.e. a
// peer configured without a client public key.
func (p ServerConfigPeer) isFallback() bool {
	hasClientKey := p.ClientPublicKey != nil
	return !hasClientKey
}
// Source validation levels accepted by ClientSourceValidateLevel and
// ServerSourceValidateLevel. The zero value, SourceValidateLevelDefault,
// means "not set": a peer left at Default inherits its server's level
// (see ServerConfigServer.Initialize).
const (
	// SourceValidateLevelDefault (0): not set; inherit the enclosing level.
	SourceValidateLevelDefault = 0

	// SourceValidateLevelNone (1):
	// do not validate the source address.
	// This allows client roaming but also carries the risk of a kind of DoS attack.
	// This is the default behavior for ClientSourceValidateLevel.
	SourceValidateLevelNone = 1

	// SourceValidateLevelIP (2):
	// validate the source address by IP only.
	// Disables client roaming across different hosts,
	// but may be compatible with some kinds of NAT.
	SourceValidateLevelIP = 2

	// SourceValidateLevelIPAndPort (3):
	// validate the source address by both IP and port.
	// Disables client roaming to defeat DoS attacks,
	// but clients must wait for a timeout and resend the MessageInitiation
	// if their IP address really changed.
	// This is the default behavior for ServerSourceValidateLevel.
	SourceValidateLevelIPAndPort = 3
)
// ServerConfigServer describes one WireGuard server identity (private key)
// and the peers reachable through it. Call Initialize before use.
type ServerConfigServer struct {
	// PrivateKey is the server's private key. Exactly one of PrivateKey and
	// PrivateKeyFile must be set; Initialize loads it from the file otherwise.
	PrivateKey *NoisePrivateKey `json:"privkey"`
	// PrivateKeyFile is a path read by Initialize when PrivateKey is unset.
	PrivateKeyFile string `json:"privkey_file,omitempty"`
	// Address is the host used for any peer whose forward_to has an empty host.
	Address string `json:"address"`
	// Peers must be non-empty and contain at most one fallback peer
	// (a peer without a client public key).
	Peers []*ServerConfigPeer `json:"peers"`
	// ClientSourceValidateLevel specifies how to handle a MessageTransport
	// packet whose source address does not match prior packets. Peers that
	// leave their own value at SourceValidateLevelDefault inherit this one.
	ClientSourceValidateLevel int `json:"csvl,omitempty"`
	// ServerSourceValidateLevel specifies how to handle a MessageTransport
	// packet whose source address does not match prior packets. Peers that
	// leave their own value at SourceValidateLevelDefault inherit this one.
	ServerSourceValidateLevel int `json:"ssvl,omitempty"`
}
// Initialize validates s and fills in derived fields, mutating s and its
// peers in place.
//
// It requires at least one peer; loads PrivateKey from PrivateKeyFile when
// only the file is given (setting both or neither is an error); allows at most
// one fallback peer (a peer without a client public key); resolves each peer's
// forward_to into a UDP address; copies the server-level source validation
// levels into peers left at SourceValidateLevelDefault; and records the
// server public key on every peer.
//
// forward_to has the form "host:port" as accepted by net.SplitHostPort, so
// IPv6 hosts are written bracketed ("[::1]:51820"). An empty host
// (":51820") falls back to s.Address, which may be a bare or bracketed
// IPv6 literal.
func (s *ServerConfigServer) Initialize() (err error) {
	if len(s.Peers) == 0 {
		return fmt.Errorf("no peers")
	}
	switch {
	case s.PrivateKey == nil && s.PrivateKeyFile == "":
		return fmt.Errorf("no server private key provided")
	case s.PrivateKey != nil && s.PrivateKeyFile != "":
		return fmt.Errorf("cannot specify both privkey and privkey_file")
	case s.PrivateKey == nil:
		s.PrivateKey = &NoisePrivateKey{}
		if err = s.PrivateKey.ReadFromFile(s.PrivateKeyFile); err != nil {
			return fmt.Errorf("cannot read private key from file %s: %w", s.PrivateKeyFile, err)
		}
	}

	// The public key depends only on the private key; derive it once rather
	// than once per peer.
	serverPublicKey := s.PrivateKey.PublicKey()

	foundFallback := false
	for pi, p := range s.Peers {
		// A JSON `null` entry in "peers" decodes to a nil pointer.
		if p == nil {
			return fmt.Errorf("peer[%d] is null", pi)
		}
		if p.ClientPublicKey == nil {
			if foundFallback {
				return fmt.Errorf("multiple fallback peers found")
			}
			foundFallback = true
		}

		if len(p.ForwardTo) == 0 {
			return fmt.Errorf("peer[%d] has no forward_to address", pi)
		}
		// SplitHostPort (unlike splitting on ":") understands bracketed
		// IPv6 hosts such as "[::1]:51820".
		host, port, splitErr := net.SplitHostPort(strings.TrimSpace(p.ForwardTo))
		if splitErr != nil {
			return fmt.Errorf("peer[%d] has invalid forward_to address %s: %w", pi, p.ForwardTo, splitErr)
		}
		host = strings.TrimSpace(host)
		port = strings.TrimSpace(port)
		if host == "" {
			host = strings.TrimSpace(s.Address)
			// Strip optional brackets so JoinHostPort re-adds exactly one pair
			// for IPv6 literals.
			if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
				host = host[1 : len(host)-1]
			}
		}
		p.forwardToAddress, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port))
		if err != nil {
			return fmt.Errorf("peer[%d] has invalid forward_to address %s: %w", pi, p.ForwardTo, err)
		}

		// Per-peer levels override the server's; unset ones inherit it.
		if p.ClientSourceValidateLevel == SourceValidateLevelDefault {
			p.ClientSourceValidateLevel = s.ClientSourceValidateLevel
		}
		if p.ServerSourceValidateLevel == SourceValidateLevelDefault {
			p.ServerSourceValidateLevel = s.ServerSourceValidateLevel
		}
		p.serverPublicKey = serverPublicKey
	}
	return nil
}
// ServerConfig is the top-level configuration consumed by NewServerWithConfig.
type ServerConfig struct {
	// Listen is the client-facing address, resolved with net.ResolveUDPAddr("udp", ...).
	Listen string `json:"listen"`
	// Timeout, in seconds, overrides the translation table's timeout when > 0.
	Timeout int `json:"timeout,omitempty"`
	// MaxPacketSize overrides the translation table's maximum packet size when > 0.
	MaxPacketSize int `json:"max_packet_size,omitempty"`
	// Servers must contain at least one entry; each is Initialize()d.
	Servers []*ServerConfigServer `json:"servers"`
	// ObfuscateKey is passed to WireGuardObfuscator.Initialize for the
	// client-side read/write functions.
	ObfuscateKey string `json:"obfs"`
	// WGITCacheConfig is copied into the translation table's CacheJar.
	WGITCacheConfig
}
// Server forwards client traffic to configured peers. Construct it with
// NewServerWithConfig and run it with Start.
type Server struct {
	// wgitTable performs the actual serving; Start calls its Serve method.
	wgitTable *WireGuardIndexTranslationTable
	// servers are the initialized server configs, tried in order by extractPeer.
	servers []*ServerConfigServer
}
// NewServerWithConfig validates config and builds a ready-to-Start Server.
//
// Every entry of config.Servers is initialized in place (see
// ServerConfigServer.Initialize), so config is mutated, and the returned
// Server shares that same slice. Timeout (seconds) and MaxPacketSize replace
// the translation table's values only when positive. Client-side UDP reads
// and writes go through a WireGuardObfuscator keyed with config.ObfuscateKey.
//
// On error the returned *Server is nil.
func NewServerWithConfig(config *ServerConfig) (outServer *Server, err error) {
	if len(config.Servers) == 0 {
		return nil, errors.New("no server defined")
	}
	for idx, srvCfg := range config.Servers {
		if initErr := srvCfg.Initialize(); initErr != nil {
			return nil, fmt.Errorf("server[%d]: %w", idx, initErr)
		}
	}

	srv := &Server{servers: config.Servers}
	srv.wgitTable = NewWireGuardIndexTranslationTable()
	table := srv.wgitTable

	listenAddr, resolveErr := net.ResolveUDPAddr("udp", config.Listen)
	if resolveErr != nil {
		return nil, fmt.Errorf("invalid listen address %s: %w", config.Listen, resolveErr)
	}
	table.ClientListen = listenAddr

	if config.Timeout > 0 {
		table.Timeout = time.Duration(config.Timeout) * time.Second
	}
	if config.MaxPacketSize > 0 {
		table.MaxPacketSize = uint(config.MaxPacketSize)
	}
	// The table asks this server which configured peer a handshake belongs to.
	table.ExtractPeerFunc = srv.extractPeer
	table.CacheJar.WGITCacheConfig = config.WGITCacheConfig

	var obfuscator WireGuardObfuscator
	obfuscator.Initialize(config.ObfuscateKey)
	table.ClientWriteToUDPFunc = obfuscator.WriteToUDPWithObfuscate
	table.ClientReadFromUDPFunc = obfuscator.ReadFromUDPWithDeobfuscate

	return srv, nil
}
// extractPeer decides which configured server and peer a handshake
// initiation belongs to. NewServerWithConfig installs it as the translation
// table's ExtractPeerFunc.
//
// Servers are tried in configuration order. For each one, the responder-side
// steps of device.Device.ConsumeMessageInitiation are replayed with that
// server's private key to AEAD-open msg.Static and recover the initiator's
// static public key. The first server whose key opens it is selected. Within
// that server, a peer whose ClientPublicKey equals the recovered key is
// preferred; otherwise the server's fallback peer (no ClientPublicKey) is used.
//
// The result is a shallow copy of the configured peer with ClientPublicKey
// set to the recovered key; the configured peer itself is not modified.
//
// An error is returned if no server is configured, if no server key opens
// msg.Static (including an all-zero shared secret), or if the selected server
// has neither a matching nor a fallback peer.
func (s *Server) extractPeer(msg *device.MessageInitiation) (sp *ServerConfigPeer, err error) {
	// tryDecryptPeerPKWith recovers the initiator's static public key using
	// privateKey, returning a non-nil error if privateKey cannot open msg.
	tryDecryptPeerPKWith := func(privateKey NoisePrivateKey) (peerPK NoisePublicKey, err error) {
		ourPublicKey := privateKey.PublicKey()
		// most implementation here is copied from device.Device.ConsumeMessageInitiation().
		var (
			hash     [blake2s.Size]byte
			chainKey [blake2s.Size]byte
		)
		devicex.mixHash(&hash, &device.InitialHash, ourPublicKey.NoisePublicKey[:])
		devicex.mixHash(&hash, &hash, msg.Ephemeral[:])
		devicex.mixKey(&chainKey, &device.InitialChainKey, msg.Ephemeral[:])

		// decrypt static key
		var key [chacha20poly1305.KeySize]byte
		ss := privateKey.SharedSecret(msg.Ephemeral)
		if devicex.isZero(ss[:]) {
			// Must be a failure: returning a nil error here would make the
			// caller accept msg as decrypted by this key, with a zero peerPK.
			err = errors.New("ephemeral key yields an all-zero shared secret")
			return
		}
		device.KDF2(&chainKey, &key, chainKey[:], ss[:])
		// chacha20poly1305.New only fails on a wrong key length; key is an
		// array of exactly KeySize bytes.
		aead, _ := chacha20poly1305.New(key[:])
		// Open decrypts in place into peerPK's backing array.
		_, err = aead.Open(peerPK.NoisePublicKey[:0], device.ZeroNonce[:], msg.Static[:], hash[:])
		// TODO: now we have peerPK, but we can do further validation to protect against replay & flood
		return
	}

	if len(s.servers) == 0 {
		return nil, fmt.Errorf("no server configured")
	}

	var matchedServer *ServerConfigServer
	var peerPK NoisePublicKey
	for _, server := range s.servers {
		peerPK, err = tryDecryptPeerPKWith(*server.PrivateKey)
		if err == nil {
			matchedServer = server
			break
		}
	}
	if err != nil {
		return nil, fmt.Errorf("no server private key decrypted the message: %w", err)
	}

	var matchedServerPeer, fallbackServerPeer *ServerConfigPeer
	for _, peer := range matchedServer.Peers {
		if peer.isFallback() {
			fallbackServerPeer = peer
		} else if peer.ClientPublicKey.Equals(peerPK.NoisePublicKey) {
			matchedServerPeer = peer
		}
	}
	if matchedServerPeer == nil {
		matchedServerPeer = fallbackServerPeer
	}
	if matchedServerPeer == nil {
		// Identify the server by its PUBLIC key: this error may be logged and
		// must never carry the private key.
		serverPK := matchedServer.PrivateKey.PublicKey()
		return nil, fmt.Errorf("no matched server peer and no fallback server peer for server %s",
			base64.StdEncoding.EncodeToString(serverPK.NoisePublicKey[:]))
	}

	copiedPeer := *matchedServerPeer
	copiedPeer.ClientPublicKey = &peerPK
	return &copiedPeer, nil
}
// Start logs the client-facing listen address, then calls the translation
// table's Serve and returns its error.
func (s *Server) Start() error {
	listenAddr := s.wgitTable.ClientListen
	log.Printf("[info] listen on %s ...\n", listenAddr)
	return s.wgitTable.Serve()
}