Skip to content

Commit

Permalink
Add support for Multipath TCP
Browse files Browse the repository at this point in the history
  • Loading branch information
neilalexander committed May 31, 2024
1 parent fec96a3 commit ed89915
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 42 deletions.
18 changes: 10 additions & 8 deletions src/admin/getpeers.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type PeerEntry struct {
PublicKey string `json:"key"`
Port uint64 `json:"port"`
Priority uint64 `json:"priority"`
Multipath bool `json:"multipath,omitempty"`
RXBytes DataUnit `json:"bytes_recvd,omitempty"`
TXBytes DataUnit `json:"bytes_sent,omitempty"`
Uptime float64 `json:"uptime,omitempty"`
Expand All @@ -37,14 +38,15 @@ func (a *AdminSocket) getPeersHandler(req *GetPeersRequest, res *GetPeersRespons
res.Peers = make([]PeerEntry, 0, len(peers))
for _, p := range peers {
peer := PeerEntry{
Port: p.Port,
Up: p.Up,
Inbound: p.Inbound,
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
URI: p.URI,
RXBytes: DataUnit(p.RXBytes),
TXBytes: DataUnit(p.TXBytes),
Uptime: p.Uptime.Seconds(),
Port: p.Port,
Up: p.Up,
Inbound: p.Inbound,
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
Multipath: p.Multipath,
URI: p.URI,
RXBytes: DataUnit(p.RXBytes),
TXBytes: DataUnit(p.TXBytes),
Uptime: p.Uptime.Seconds(),
}
if p.Latency > 0 {
peer.Latency = p.Latency
Expand Down
2 changes: 2 additions & 0 deletions src/core/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type PeerInfo struct {
Coords []uint64
Port uint64
Priority uint8
Multipath bool
RXBytes uint64
TXBytes uint64
Uptime time.Duration
Expand Down Expand Up @@ -87,6 +88,7 @@ func (c *Core) GetPeers() []PeerInfo {
peerinfo.RXBytes = atomic.LoadUint64(&c.rx)
peerinfo.TXBytes = atomic.LoadUint64(&c.tx)
peerinfo.Uptime = time.Since(c.up)
peerinfo.Multipath = isMPTCP(c)
}
if p, ok := conns[conn]; ok {
peerinfo.Key = p.Key
Expand Down
53 changes: 40 additions & 13 deletions src/core/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type links struct {

type linkProtocol interface {
dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error)
listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error)
listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error)
}

// linkInfo is used as a map key
Expand Down Expand Up @@ -72,6 +72,7 @@ type linkOptions struct {
tlsSNI string
password []byte
maxBackoff time.Duration
multipath bool
}

type Listener struct {
Expand Down Expand Up @@ -140,6 +141,7 @@ const ErrLinkPinnedKeyInvalid = linkError("pinned public key is invalid")
const ErrLinkPasswordInvalid = linkError("password is invalid")
const ErrLinkUnrecognisedSchema = linkError("link schema unknown")
const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid")
const ErrLinkMultipathInvalid = linkError("multipath invalid")

func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
var retErr error
Expand Down Expand Up @@ -193,6 +195,17 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
}
options.maxBackoff = d
}
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
retErr = ErrLinkMultipathInvalid
return
}
}
// SNI headers must contain hostnames and not IP addresses, so we must make sure
// that we do not populate the SNI with an IP literal. We do this by splitting
// the host-port combo from the query option and then seeing if it parses to an
Expand Down Expand Up @@ -379,7 +392,7 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
return retErr
}

func (l *links) remove(u *url.URL, sintf string, linkType linkType) error {
func (l *links) remove(u *url.URL, sintf string, _ linkType) error {
var retErr error
phony.Block(l, func() {
// Generate the link info and see whether we think we already
Expand Down Expand Up @@ -422,31 +435,45 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
cancel()
return nil, ErrLinkUnrecognisedSchema
}
listener, err := protocol.listen(ctx, u, sintf)
if err != nil {
cancel()
return nil, err
}
li := &Listener{
listener: listener,
ctx: ctx,
Cancel: cancel,
}

var options linkOptions
if p := u.Query().Get("priority"); p != "" {
pi, err := strconv.ParseUint(p, 10, 8)
if err != nil {
cancel()
return nil, ErrLinkPriorityInvalid
}
options.priority = uint8(pi)
}
if p := u.Query().Get("password"); p != "" {
if len(p) > blake2b.Size {
cancel()
return nil, ErrLinkPasswordInvalid
}
options.password = []byte(p)
}
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
cancel()
return nil, ErrLinkMultipathInvalid
}
}

listener, err := protocol.listen(ctx, u, sintf, options)
if err != nil {
cancel()
return nil, err
}
li := &Listener{
listener: listener,
ctx: ctx,
Cancel: cancel,
}

go func() {
l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr())
Expand Down Expand Up @@ -567,7 +594,7 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s
switch {
case err != nil:
return fmt.Errorf("write handshake: %w", err)
case err == nil && n != len(metaBytes):
case n != len(metaBytes):
return fmt.Errorf("incomplete handshake send")
}
meta = version_metadata{}
Expand Down
2 changes: 1 addition & 1 deletion src/core/link_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, option
}, nil
}

func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion src/core/link_socks.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options
return conn, nil
}

func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return nil, fmt.Errorf("SOCKS listener not supported")
}
19 changes: 13 additions & 6 deletions src/core/link_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type tcpDialer struct {
addr *net.TCPAddr
}

func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) {
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo, options linkOptions) ([]*tcpDialer, error) {
host, p, err := net.SplitHostPort(url.Host)
if err != nil {
return nil, err
Expand All @@ -55,7 +55,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
IP: ip,
Port: port,
}
dialer, err := l.dialerFor(addr, info.sintf)
dialer, err := l.dialerFor(addr, info.sintf, options.multipath)
if err != nil {
continue
}
Expand All @@ -69,7 +69,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
}

func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.dialersFor(url, info)
dialers, err := l.dialersFor(url, info, options)
if err != nil {
return nil, err
}
Expand All @@ -88,17 +88,21 @@ func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err
}

func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
hostport := url.Host
if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
}
}
return l.listenconfig.Listen(ctx, "tcp", hostport)
lc := *l.listenconfig
if options.multipath {
setMPTCPForListener(&lc)
}
return lc.Listen(ctx, "tcp", hostport)
}

func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) {
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string, mptcp bool) (*net.Dialer, error) {
if dst.IP.IsLinkLocalUnicast() {
if sintf != "" {
dst.Zone = sintf
Expand All @@ -112,6 +116,9 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
KeepAlive: -1,
Control: l.tcpContext,
}
if mptcp {
setMPTCPForDialer(dialer)
}
if sintf != "" {
dialer.Control = l.getControl(sintf)
ief, err := net.InterfaceByName(sintf)
Expand Down
30 changes: 30 additions & 0 deletions src/core/link_tcp_mptcp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package core

import (
"crypto/tls"
"net"
)

func setMPTCPForDialer(d *net.Dialer) {
d.SetMultipathTCP(true)
}

func setMPTCPForListener(lc *net.ListenConfig) {
lc.SetMultipathTCP(true)
}

func isMPTCP(c net.Conn) bool {
switch tc := c.(type) {
case *net.TCPConn:
mp, _ := tc.MultipathTCP()
return mp
case *tls.Conn:
if tc, ok := tc.NetConn().(*net.TCPConn); ok {
mp, _ := tc.MultipathTCP()
return mp
}
return false
default:
return false
}
}
16 changes: 4 additions & 12 deletions src/core/link_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package core
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"

Expand Down Expand Up @@ -34,7 +33,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
}

func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.tcp.dialersFor(url, info)
dialers, err := l.tcp.dialersFor(url, info, options)
if err != nil {
return nil, err
}
Expand All @@ -58,17 +57,10 @@ func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err
}

func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
hostport := url.Host
if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
}
}
listener, err := l.listener.Listen(ctx, "tcp", hostport)
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
listener, err := l.tcp.listen(ctx, url, sintf, options)
if err != nil {
return nil, err
}
tlslistener := tls.NewListener(listener, l.config)
return tlslistener, nil
return tls.NewListener(listener, l.config), nil
}
2 changes: 1 addition & 1 deletion src/core/link_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, option
return l.dialer.DialContext(ctx, "unix", addr.String())
}

func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return l.listener.Listen(ctx, "unix", url.Path)
}

0 comments on commit ed89915

Please sign in to comment.