-
Notifications
You must be signed in to change notification settings - Fork 4
/
connection.go
386 lines (340 loc) · 10.2 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
// Copyright 2014 The zephyr-go authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package zephyr
import (
"errors"
"net"
"sync"
"time"
"github.com/zephyr-im/krb5-go"
)
// localIPForUDPAddr reports the local IP address the operating system
// selects as the source address for UDP traffic to addr. It does so by
// connecting a throwaway UDP socket to addr and reading back the bound
// local address; connecting a UDP socket does not itself transmit a
// datagram.
func localIPForUDPAddr(addr *net.UDPAddr) (net.IP, error) {
	probe, err := net.DialUDP("udp", nil, addr)
	if err != nil {
		return nil, err
	}
	defer probe.Close()
	local := probe.LocalAddr().(*net.UDPAddr)
	return local.IP, nil
}
// udpAddrsEqual reports whether a and b name the same UDP endpoint:
// same port, same IPv6 zone, and IPs that compare equal under
// net.IP.Equal (so an IPv4 address matches its 16-byte form).
// Neither argument may be nil.
func udpAddrsEqual(a, b *net.UDPAddr) bool {
	if a.Port != b.Port || a.Zone != b.Zone {
		return false
	}
	return a.IP.Equal(b.IP)
}
// serverRefreshInterval is how long the background refresh loop waits
// between re-resolving the server list from the ServerConfig.
const serverRefreshInterval time.Duration = 600 * time.Second
// A Connection is a low-level connection to the Zephyr servers. It
// discovers the servers and sends and receives Notices. It offers no
// high-level features such as subscriptions or message sharding, and it
// never sends CLIENTACKs on its own.
type Connection struct {
	// Set once at construction time.
	conn    net.PacketConn
	server  ServerConfig
	cred    *krb5.Credential
	clock   Clock
	localIP net.IP

	// allNotices carries every result read from conn. readLoop
	// routes server ACKs to pending sends and forwards the rest to
	// notices.
	allNotices <-chan NoticeReaderResult
	notices    chan NoticeReaderResult

	// ackTableLock guards ackTable, which maps the UID of each send
	// still awaiting a server ACK to the channel the ACK is delivered on.
	ackTableLock sync.Mutex
	ackTable     map[UID]chan NoticeReaderResult

	// schedLock guards sched, the current list of server addresses,
	// and schedIdx, the position in it to try first.
	schedLock sync.Mutex
	sched     []*net.UDPAddr
	schedIdx  int

	// stopRefreshing tells refreshLoop to exit; Close sends on it.
	stopRefreshing chan int
}
// NewConnection wraps conn in a new Connection driven by the system
// clock. server tells the connection how to find the Zephyr servers,
// and cred authenticates packets in both directions. The returned
// Connection owns conn from then on and closes it in Close.
func NewConnection(
	conn net.PacketConn,
	server ServerConfig,
	cred *krb5.Credential,
) (*Connection, error) {
	clock := SystemClock
	return NewConnectionFull(conn, server, cred, clock)
}
// NewConnectionFull is NewConnection with an explicit Clock, so tests
// can substitute a mock clock.
//
// Construction first resolves the server schedule and picks the local
// source IP. Both steps can fail, and only after they succeed does the
// connection start reading from conn and spawn its background
// goroutines. On error, nothing has been started and conn is not closed.
func NewConnectionFull(
	conn net.PacketConn,
	server ServerConfig,
	cred *krb5.Credential,
	clock Clock,
) (*Connection, error) {
	c := &Connection{
		conn:           conn,
		server:         server,
		cred:           cred,
		clock:          clock,
		notices:        make(chan NoticeReaderResult),
		ackTable:       make(map[UID]chan NoticeReaderResult),
		stopRefreshing: make(chan int, 1),
	}

	// Do the fallible work before starting anything that would need
	// tearing down. Previously the reader was started first and was
	// left attached to conn, with no consumer, if either step failed.
	sched, err := c.RefreshServer()
	if err != nil {
		return nil, err
	}
	localIP, err := localIPForUDPAddr(sched[0])
	if err != nil {
		return nil, err
	}
	c.localIP = localIP

	// Construction can no longer fail, so start reading from conn.
	// A nil credential means notices are read without a key.
	var key *krb5.KeyBlock
	if cred != nil {
		key = cred.KeyBlock
	}
	c.allNotices = ReadNoticesFromServer(conn, key)
	go c.readLoop()

	// This is kinda screwy. Purely for testing purposes, the first
	// clock query is made here, synchronously, rather than inside
	// refreshLoop. That way it has happened by the time
	// NewConnectionFull returns. MockClock is a little messy.
	go c.refreshLoop(c.clock.After(serverRefreshInterval))
	return c, nil
}
// Notices returns a receive-only view of the incoming non-ACK notices.
// The channel is closed once the stream of results read from the
// PacketConn ends.
func (c *Connection) Notices() <-chan NoticeReaderResult {
	var recvOnly <-chan NoticeReaderResult = c.notices
	return recvOnly
}
// LocalAddr returns the local UDP address the client uses to talk to
// the Zephyr servers. It combines the port the PacketConn is bound to
// with the source IP chosen for the servers at construction time (the
// socket itself may be bound to a wildcard address).
//
// The result is a fresh copy; the caller may modify it freely. It
// panics if the PacketConn's local address is not a *net.UDPAddr.
func (c *Connection) LocalAddr() *net.UDPAddr {
	// Copy before overwriting IP. A *net.UDPConn's LocalAddr returns
	// the socket's internal address value, so assigning to it in place
	// would silently rewrite the connection's own record of its local
	// address and race with other readers of it.
	addr := *c.conn.LocalAddr().(*net.UDPAddr)
	// Also detach the IP bytes so callers cannot mutate c.localIP.
	addr.IP = append(net.IP(nil), c.localIP...)
	return &addr
}
// Credential reports the Kerberos credential this connection was
// created with. It may be nil; construction accepts a nil credential.
func (c *Connection) Credential() *krb5.Credential {
	cred := c.cred
	return cred
}
// Close stops the background server-refresh loop and closes the
// underlying PacketConn, returning the PacketConn's Close error. It is
// safe to call more than once. Repeat calls never block; they return
// whatever the PacketConn reports for a second Close.
func (c *Connection) Close() error {
	// stopRefreshing has a one-slot buffer. A blocking send would hang
	// forever once that slot is full and refreshLoop has already
	// exited, e.g. on a third call to Close. If the slot is full, a stop
	// request is already pending, so skipping the send loses nothing.
	select {
	case c.stopRefreshing <- 0:
	default:
	}
	return c.conn.Close()
}
// readLoop demultiplexes incoming results until allNotices is closed.
// Server ACKs are handed to the SendPacket call waiting on their UID;
// every other notice is forwarded to the Notices channel. When the
// input ends, the Notices channel is closed.
//
// NOTE(review): forwarding blocks until someone reads from Notices, so
// an undrained Notices channel also stalls ACK delivery.
func (c *Connection) readLoop() {
	for result := range c.allNotices {
		if !result.Notice.Kind.IsServerACK() {
			c.notices <- result
			continue
		}
		c.processServAck(result)
	}
	close(c.notices)
}
// refreshLoop re-resolves the server schedule each time its timer
// fires, then re-arms the timer for serverRefreshInterval. It exits
// when a value arrives on stopRefreshing. Refresh errors are
// deliberately ignored here: a failed resolution returns before
// touching the current schedule, so the previous one stays in use.
func (c *Connection) refreshLoop(after <-chan time.Time) {
	timer := after
	for {
		select {
		case <-c.stopRefreshing:
			return
		case <-timer:
			_, _ = c.RefreshServer()
			timer = c.clock.After(serverRefreshInterval)
		}
	}
}
// findPendingSend removes and returns the ACK channel registered for
// uid, or nil if no send is waiting on that UID. Because the entry is
// removed, each registration can receive at most one ACK.
func (c *Connection) findPendingSend(uid UID) chan NoticeReaderResult {
	c.ackTableLock.Lock()
	pending, found := c.ackTable[uid]
	if found {
		delete(c.ackTable, uid)
	}
	c.ackTableLock.Unlock()
	return pending
}
// addPendingSend registers interest in the server ACK for uid and
// returns the channel the ACK will be delivered on. Any earlier
// registration for the same uid is replaced.
func (c *Connection) addPendingSend(uid UID) <-chan NoticeReaderResult {
	// A one-slot buffer lets processServAck hand off the ACK without
	// waiting. This matters when the sender has already timed out and
	// is no longer receiving.
	delivery := make(chan NoticeReaderResult, 1)
	c.ackTableLock.Lock()
	c.ackTable[uid] = delivery
	c.ackTableLock.Unlock()
	return delivery
}
// clearPendingSend drops any ACK registration for uid. It is a no-op
// if none exists.
func (c *Connection) clearPendingSend(uid UID) {
	c.ackTableLock.Lock()
	delete(c.ackTable, uid)
	c.ackTableLock.Unlock()
}
// processServAck delivers a server ACK to the send registered for its
// UID. ACKs nobody is waiting for are dropped.
func (c *Connection) processServAck(r NoticeReaderResult) {
	if waiter := c.findPendingSend(r.Notice.UID); waiter != nil {
		// Cannot block: the channel has a one-slot buffer and was just
		// removed from the table, so this is its only send.
		waiter <- r
	}
}
// schedule returns a consistent snapshot of the server list together
// with the index of the server to try first.
func (c *Connection) schedule() (servers []*net.UDPAddr, first int) {
	c.schedLock.Lock()
	servers, first = c.sched, c.schedIdx
	c.schedLock.Unlock()
	return servers, first
}
// setSchedule atomically installs a new server list and the index of
// the server to try first.
func (c *Connection) setSchedule(sched []*net.UDPAddr, schedIdx int) {
	c.schedLock.Lock()
	c.sched, c.schedIdx = sched, schedIdx
	c.schedLock.Unlock()
}
// goodServer records that good answered a send, so the next send
// starts with it. If good is not in the current schedule, nothing
// changes.
func (c *Connection) goodServer(good *net.UDPAddr) {
	c.schedLock.Lock()
	defer c.schedLock.Unlock()
	for i := 0; i < len(c.sched); i++ {
		if udpAddrsEqual(c.sched[i], good) {
			c.schedIdx = i
			break
		}
	}
}
// RefreshServer re-resolves the server schedule from the ServerConfig,
// installs it with its first entry as the preferred server, and returns
// it. It already runs periodically and when outgoing sends keep timing
// out, so there should be little need to call it manually.
//
// If resolution fails, or yields no servers, RefreshServer returns an
// error and leaves the current schedule unchanged.
func (c *Connection) RefreshServer() ([]*net.UDPAddr, error) {
	sched, err := c.server.ResolveServer()
	if err != nil {
		return nil, err
	}
	// Sends cannot work with an empty schedule. Report this as an
	// ordinary error rather than panicking: this method also runs on
	// the background refresh goroutine, where a panic would take down
	// the whole process.
	if len(sched) == 0 {
		return nil, errors.New("zephyr: server resolution returned no servers")
	}
	c.setSchedule(sched, 0)
	return sched, nil
}
// SendNotice authenticates n with the connection's credential and sends
// it to the servers. When the notice's kind expects an acknowledgement,
// the server's SERVACK or SERVNAK notice is returned.
func (c *Connection) SendNotice(ctx *krb5.Context, n *Notice) (*Notice, error) {
	wire, encErr := n.EncodePacketForServer(ctx, c.cred)
	if encErr != nil {
		return nil, encErr
	}
	return c.SendPacket(wire, n.Kind, n.UID)
}
// SendNoticeUnauth sends n to the servers without authentication. When
// the notice's kind expects an acknowledgement, the server's SERVACK or
// SERVNAK notice is returned.
func (c *Connection) SendNoticeUnauth(n *Notice) (*Notice, error) {
	wire := n.EncodePacketUnauth()
	return c.SendPacket(wire, n.Kind, n.UID)
}
// SendNoticeUnackedTo encodes n without authentication and sends it
// once to addr, without waiting for any acknowledgement. Its intended
// use is replying to a received notice with a CLIENTACK.
func (c *Connection) SendNoticeUnackedTo(n *Notice, addr net.Addr) error {
	wire := n.EncodePacketUnauth()
	return c.SendPacketUnackedTo(wire, addr)
}
// Errors returned by the send paths.
var (
	// ErrPacketTooLong is returned when an encoded notice or raw
	// packet exceeds MaxPacketLength, the maximum Zephyr packet size.
	ErrPacketTooLong = errors.New("packet too long")

	// ErrSendTimeout is returned when a send exhausts its retransmit
	// schedule without an acknowledgement from any server.
	ErrSendTimeout = errors.New("send timeout")
)
// SendPacketUnackedTo writes the raw packet pkt to addr once, without
// waiting for any acknowledgement. Packets larger than MaxPacketLength
// are rejected with ErrPacketTooLong and never written.
func (c *Connection) SendPacketUnackedTo(pkt []byte, addr net.Addr) error {
	if len(pkt) <= MaxPacketLength {
		_, writeErr := c.conn.WriteTo(pkt, addr)
		return writeErr
	}
	return ErrPacketTooLong
}
// retrySchedule gives, for each transmission SendPacket makes, how long
// to wait for an ACK before the next one. Its length is therefore the
// maximum number of transmissions per send.
//
// TODO(davidben): We probably want to be cleverer later. For now,
// follow a similar strategy to the real zhm, but use a much more
// aggressive rexmit schedule.
//
// Empirically, it seems to take 15-20ms for the zephyrds to ACK a
// notice.
var retrySchedule = []time.Duration{
	100 * time.Millisecond,
	100 * time.Millisecond,
	250 * time.Millisecond,
	500 * time.Millisecond,
	time.Second,
	2 * time.Second,
	4 * time.Second,
}

// timeoutsBeforeRefresh is how many ACK timeouts SendPacket tolerates
// before fetching a fresh server schedule.
const timeoutsBeforeRefresh = 4
// SendPacket sends a raw, already-encoded packet to the Zephyr servers.
//
// If kind.ExpectsServerACK() is false, the packet is written once, to
// the server at the current schedule position, and SendPacket returns
// (nil, nil). Otherwise it retransmits on the timings in retrySchedule,
// moving to the next server in the schedule on each attempt. It stops
// when the SERVACK or SERVNAK notice for uid arrives, returning that
// notice, or when the schedule is exhausted, returning ErrSendTimeout.
// After timeoutsBeforeRefresh timeouts it re-resolves the server list
// once; an error from that refresh is returned. The server whose ACK
// arrives is remembered so later sends start there.
//
// ErrPacketTooLong is returned, with nothing sent, if pkt exceeds
// MaxPacketLength. Any error writing to the PacketConn is returned
// immediately.
func (c *Connection) SendPacket(pkt []byte, kind Kind, uid UID) (*Notice, error) {
	// TODO(davidben): Should we limit the number of packets
	// in-flight as an ad-hoc congestion control?
	if len(pkt) > MaxPacketLength {
		return nil, ErrPacketTooLong
	}
	// retryIdx is the index of the transmission about to be made. The
	// first timeout fires immediately, so the initial transmission
	// happens inside the loop like every retransmission.
	retryIdx := -1
	timeout := c.clock.After(0)
	// Listen for ACKs. When no ACK is expected, ackChan stays nil. A
	// receive from a nil channel never proceeds, so that select case is
	// effectively disabled.
	var ackChan <-chan NoticeReaderResult
	var shouldClear bool
	if kind.ExpectsServerACK() {
		// Register before the first transmission so the ACK always has
		// a waiter to be delivered to. The deferred cleanup removes the
		// registration on timeout or error. It is skipped once an ACK
		// arrives, because processServAck already removed the entry.
		ackChan = c.addPendingSend(uid)
		shouldClear = true
		defer func() {
			if shouldClear {
				c.clearPendingSend(uid)
			}
		}()
	}
	// Get the remote server schedule. This is a local snapshot; later
	// background refreshes do not affect this send unless it refreshes
	// explicitly below.
	sched, schedIdx := c.schedule()
	// RefreshServer never installs an empty schedule, so an empty one
	// here means the Connection was not built via NewConnection(Full).
	if len(sched) == 0 {
		panic(sched)
	}
	for {
		select {
		case ack := <-ackChan:
			shouldClear = false // Already taken care of.
			// Record the good server so next time we
			// start at that one.
			// NOTE(review): this type assertion panics if the
			// PacketConn reports a non-UDP source address.
			c.goodServer(ack.Addr.(*net.UDPAddr))
			return ack.Notice, nil
		case <-timeout:
			retryIdx++
			if retryIdx >= len(retrySchedule) {
				return nil, ErrSendTimeout
			}
			// Partway through the re-xmit schedule, if we
			// still haven't heard back from any server,
			// get a fresh set of remote addresses.
			if retryIdx == timeoutsBeforeRefresh {
				var err error
				sched, err = c.RefreshServer()
				if err != nil {
					return nil, err
				}
				schedIdx = 0
			}
			addr := sched[schedIdx]
			if err := c.SendPacketUnackedTo(pkt, addr); err != nil {
				// TODO(davidben): Keep going on
				// temporary errors?
				return nil, err
			}
			// Fire-and-forget kinds are done after one write.
			if !kind.ExpectsServerACK() {
				return nil, nil
			}
			// Schedule the next timeout and move on to
			// the next server.
			timeout = c.clock.After(retrySchedule[retryIdx])
			schedIdx = (schedIdx + 1) % len(sched)
		}
	}
}