-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtunneld.go
140 lines (122 loc) · 3.46 KB
/
tunneld.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package tunneld
import (
"context"
"fmt"
"net"
"net/http"
"net/netip"
"sync"
"time"
"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun/netstack"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"golang.org/x/xerrors"
)
// API holds the state created by New: the validated Options, the wireguard
// netstack and device, an HTTP transport that dials through that netstack,
// and a cache of peer public keys keyed by IP address.
//
// TODO: add logging to API
type API struct {
*Options
// wgNet is the netstack returned by netstack.CreateNetTUN in New; the
// transport's DialContext connects through it.
wgNet *netstack.Net
// wgDevice is the wireguard device configured and brought up in New;
// Close removes its peers, closes it and waits for it to finish.
wgDevice *device.Device
// transport ignores the network/addr it is asked to dial and instead
// connects over wgNet to the netip.AddrPort stored on the request context.
transport *http.Transport
// pkeyCacheMu presumably guards pkeyCache (by naming convention; its
// accessors are outside this view — verify).
pkeyCacheMu sync.RWMutex
// pkeyCache maps an IP address (presumably a peer's wireguard IP) to its
// cached public key and handshake time; New initializes it empty.
pkeyCache map[netip.Addr]cachedPeer
}
// cachedPeer is the value type of API.pkeyCache.
type cachedPeer struct {
// key is the peer's wireguard (Noise) public key.
key device.NoisePublicKey
// lastHandshake is a handshake timestamp for the peer. NOTE(review): whether
// this is the device-reported handshake or the cache insertion time is
// decided outside this view — confirm.
lastHandshake time.Time
}
// New builds an API from options.
//
// A nil options is treated as a zero Options value, and options is validated
// before anything is created. New then:
//   - creates a wireguard virtual TUN adapter and netstack addressed at
//     options.WireguardServerIP with MTU options.WireguardMTU,
//   - creates a wireguard device on that TUN, configures it with
//     options.WireguardKey and options.WireguardPort, and brings it up,
//   - builds an http.Transport whose dialer connects through the netstack.
//
// The transport's DialContext ignores its network and addr arguments. The
// destination must be stored on the request context under ipPortKey{} as a
// netip.AddrPort, and each dial is bounded by options.PeerDialTimeout.
//
// If configuring or starting the device fails, the device is closed before
// the error is returned so a failed New does not leak it.
func New(options *Options) (*API, error) {
	if options == nil {
		options = &Options{}
	}
	if err := options.Validate(); err != nil {
		return nil, xerrors.Errorf("invalid options: %w", err)
	}

	// Create the wireguard virtual TUN adapter and netstack.
	tun, wgNet, err := netstack.CreateNetTUN(
		[]netip.Addr{options.WireguardServerIP},
		// We don't do DNS resolution over the netstack, so don't specify any
		// DNS servers.
		[]netip.Addr{},
		options.WireguardMTU,
	)
	if err != nil {
		return nil, xerrors.Errorf("create wireguard virtual TUN adapter and netstack: %w", err)
	}

	// Create, configure and start the wireguard device.
	deviceLogger := options.Log.Named("wireguard_device")
	dlog := &device.Logger{
		Verbosef: func(format string, args ...interface{}) {
			deviceLogger.Debug(context.Background(), fmt.Sprintf(format, args...))
		},
		Errorf: func(format string, args ...interface{}) {
			deviceLogger.Error(context.Background(), fmt.Sprintf(format, args...))
		},
	}
	dev := device.NewDevice(tun, conn.NewDefaultBind(), dlog)
	err = dev.IpcSet(fmt.Sprintf(`private_key=%s
listen_port=%d`,
		options.WireguardKey.HexString(),
		options.WireguardPort,
	))
	if err != nil {
		// The device already exists at this point; close it so a failed New
		// does not leak it. No peers have been configured, so there is
		// nothing to remove first.
		dev.Close()
		return nil, xerrors.Errorf("configure wireguard device: %w", err)
	}
	if err := dev.Up(); err != nil {
		dev.Close()
		return nil, xerrors.Errorf("start wireguard device: %w", err)
	}

	return &API{
		Options:   options,
		wgNet:     wgNet,
		wgDevice:  dev,
		pkeyCache: make(map[netip.Addr]cachedPeer),
		transport: &http.Transport{
			// network and addr are ignored: the destination is read from the
			// ipPortKey{} context value and dialed over the wireguard netstack.
			DialContext: func(ctx context.Context, network, addr string) (nc net.Conn, err error) {
				ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "(http.Transport).DialContext")
				defer span.End()
				// Named results are assigned by each return statement before
				// deferred calls run, so this sees every returned error.
				defer func() {
					if err != nil {
						span.RecordError(err)
						span.SetStatus(codes.Error, err.Error())
					}
				}()

				ip := ctx.Value(ipPortKey{})
				if ip == nil {
					return nil, xerrors.New("no ip on context")
				}
				ipp, ok := ip.(netip.AddrPort)
				if !ok {
					// Report the type actually stored on the context; ipp is
					// only the zero netip.AddrPort on this path.
					return nil, xerrors.Errorf("ip is incorrect type, got %T", ip)
				}
				span.SetAttributes(attribute.String("wireguard_addr", ipp.Addr().String()))

				dialCtx, dialCancel := context.WithTimeout(ctx, options.PeerDialTimeout)
				defer dialCancel()
				nc, err = wgNet.DialContextTCPAddrPort(dialCtx, ipp)
				if err != nil {
					// Return a literal nil so callers never see a non-nil
					// net.Conn wrapping a nil concrete connection.
					return nil, err
				}
				return nc, nil
			},
			ForceAttemptHTTP2:     false,
			MaxIdleConns:          0,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}, nil
}
// Close shuts down the wireguard device created by New and blocks until the
// channel returned by the device's Wait yields. It always returns nil.
func (api *API) Close() error {
	dev := api.wgDevice

	// Drop every peer first: closing the device while peer goroutines are
	// still running races with dev.Close() and results in a segfault.
	dev.RemoveAllPeers()
	dev.Close()

	// Block until the device signals that shutdown has finished.
	done := dev.Wait()
	<-done

	return nil
}