Skip to content

Commit

Permalink
refactor: introduce tun (#54)
Browse files Browse the repository at this point in the history
This is the seventh (and in a sense, last) commit in the series of
incremental refactoring of the current minivpn tree. With this package
we have all the needed layers to start reasoning about the complete
architecture.

TUN uses a similar strategy to the TLSBio in the tlssession package: it
uses channels to communicate with the layer below (the data channel),
and it buffers reads.

Reference issue: #47

---------

Co-authored-by: Simone Basso <[email protected]>
  • Loading branch information
ainghazal and bassosimone authored Jan 22, 2024
1 parent 6c0c4cd commit b90d50a
Show file tree
Hide file tree
Showing 4 changed files with 422 additions and 0 deletions.
3 changes: 3 additions & 0 deletions internal/tun/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Package tun is the public interface for the minivpn application. It exposes a tun device interface
// where the user of the application can write to and read from.
package tun
129 changes: 129 additions & 0 deletions internal/tun/setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package tun

import (
"github.com/ooni/minivpn/internal/controlchannel"
"github.com/ooni/minivpn/internal/datachannel"
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/networkio"
"github.com/ooni/minivpn/internal/packetmuxer"
"github.com/ooni/minivpn/internal/reliabletransport"
"github.com/ooni/minivpn/internal/runtimex"
"github.com/ooni/minivpn/internal/session"
"github.com/ooni/minivpn/internal/tlssession"
"github.com/ooni/minivpn/internal/workers"
)

// connectChannel connects an existing channel (a "signal" in Qt terminology)
// to a nil pointer to channel (a "slot" in Qt terminology).
func connectChannel[T any](signal chan T, slot **chan T) {
runtimex.Assert(signal != nil, "signal is nil")
runtimex.Assert(slot == nil || *slot == nil, "slot or *slot aren't nil")
*slot = &signal
}

// startWorkers starts all the workers. See the [ARCHITECTURE]
// file for more information about the workers.
//
// [ARCHITECTURE]: https://github.com/ooni/minivpn/blob/main/ARCHITECTURE.md
func startWorkers(logger model.Logger, sessionManager *session.Manager,
tunDevice *TUN, conn networkio.FramingConn, options *model.Options) *workers.Manager {
// create a workers manager
workersManager := workers.NewManager()

// create the networkio service.
nio := &networkio.Service{
MuxerToNetwork: make(chan []byte, 1<<5),
NetworkToMuxer: nil, // ok
}

// create the packetmuxer service.
muxer := &packetmuxer.Service{
MuxerToReliable: nil, // ok
MuxerToData: nil, // ok
NotifyTLS: nil,
HardReset: make(chan any, 1),
DataOrControlToMuxer: make(chan *model.Packet),
MuxerToNetwork: nil, // ok
NetworkToMuxer: make(chan []byte),
}

// connect networkio and packetmuxer
connectChannel(nio.MuxerToNetwork, &muxer.MuxerToNetwork)
connectChannel(muxer.NetworkToMuxer, &nio.NetworkToMuxer)

// create the datachannel service.
datach := &datachannel.Service{
MuxerToData: make(chan *model.Packet),
DataOrControlToMuxer: nil, // ok
KeyReady: make(chan *session.DataChannelKey, 1),
TUNToData: tunDevice.tunDown,
DataToTUN: tunDevice.tunUp,
}

// connect the packetmuxer and the datachannel
connectChannel(datach.MuxerToData, &muxer.MuxerToData)
connectChannel(muxer.DataOrControlToMuxer, &datach.DataOrControlToMuxer)

// create the reliabletransport service.
rel := &reliabletransport.Service{
DataOrControlToMuxer: nil, // ok
ControlToReliable: make(chan *model.Packet),
MuxerToReliable: make(chan *model.Packet),
ReliableToControl: nil, // ok
}

// connect reliable service and packetmuxer.
connectChannel(rel.MuxerToReliable, &muxer.MuxerToReliable)
connectChannel(muxer.DataOrControlToMuxer, &rel.DataOrControlToMuxer)

// create the controlchannel service.
ctrl := &controlchannel.Service{
NotifyTLS: nil, // ok
ControlToReliable: nil, // ok
ReliableToControl: make(chan *model.Packet),
TLSRecordToControl: make(chan []byte),
TLSRecordFromControl: nil, // ok
}

// connect the reliable service and the controlchannel service
connectChannel(rel.ControlToReliable, &ctrl.ControlToReliable)
connectChannel(ctrl.ReliableToControl, &rel.ReliableToControl)

// create the tlssession service
tlsx := &tlssession.Service{
NotifyTLS: make(chan *model.Notification, 1),
KeyUp: nil,
TLSRecordUp: make(chan []byte),
TLSRecordDown: nil,
}

// connect the tlsstate service and the controlchannel service
connectChannel(tlsx.NotifyTLS, &ctrl.NotifyTLS)
connectChannel(tlsx.TLSRecordUp, &ctrl.TLSRecordFromControl)
connectChannel(ctrl.TLSRecordToControl, &tlsx.TLSRecordDown)

// connect tlsstate service and the datachannel service
connectChannel(datach.KeyReady, &tlsx.KeyUp)

// connect the muxer and the tlsstate service
connectChannel(tlsx.NotifyTLS, &muxer.NotifyTLS)

logger.Debugf("%T: %+v", nio, nio)
logger.Debugf("%T: %+v", muxer, muxer)
logger.Debugf("%T: %+v", rel, rel)
logger.Debugf("%T: %+v", ctrl, ctrl)
logger.Debugf("%T: %+v", tlsx, tlsx)

// start all the workers
nio.StartWorkers(logger, workersManager, conn)
muxer.StartWorkers(logger, workersManager, sessionManager)
rel.StartWorkers(logger, workersManager, sessionManager)
ctrl.StartWorkers(logger, workersManager, sessionManager)
datach.StartWorkers(logger, workersManager, sessionManager, options)
tlsx.StartWorkers(logger, workersManager, sessionManager, options)

// tell the packetmuxer that it should handshake ASAP
muxer.HardReset <- true

return workersManager
}
208 changes: 208 additions & 0 deletions internal/tun/tun.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package tun

import (
"bytes"
"context"
"errors"
"net"
"os"
"sync"
"time"

"github.com/apex/log"
"github.com/ooni/minivpn/internal/model"
"github.com/ooni/minivpn/internal/networkio"
"github.com/ooni/minivpn/internal/session"
)

var (
ErrInitializationTimeout = errors.New("timeout while waiting for TUN to start")
)

// StartTUN initializes and starts the TUN device over the vpn.
// If the passed context expires before the TUN device is ready,
func StartTUN(ctx context.Context, conn networkio.FramingConn, options *model.Options) (*TUN, error) {
// create a session
sessionManager, err := session.NewManager(log.Log)
if err != nil {
return nil, err
}

// create the TUN that will OWN the connection
tunnel := newTUN(log.Log, conn, sessionManager)

// start all the workers
workers := startWorkers(log.Log, sessionManager, tunnel, conn, options)
tunnel.whenDone(func() {
workers.StartShutdown()
workers.WaitWorkersShutdown()
})

// Await for the signal from the session manager to tell us we're ready to start accepting data.
// In practice, this means that we already have a valid TunnelInfo at this point
// (i.e., three way handshake has completed, and we have valid keys).

select {
case <-ctx.Done():
return nil, ErrInitializationTimeout
case <-sessionManager.Ready:
return tunnel, nil
}
}

// TUN allows to use channels to read and write. It also OWNS the underlying connection.
// TUN implements net.Conn
type TUN struct {
// ensure idempotency.
closeOnce sync.Once

// conn is the underlying connection.
conn networkio.FramingConn

// hangup is used to let methods know the connection is closed.
hangup chan any

// logger implements model.Logger
logger model.Logger

// network is the underlying network for the passed [networkio.FramingConn].
network string

// used to buffer reads from above.
readBuffer *bytes.Buffer

// readDeadline is used to set the read deadline.
readDeadline tunDeadline

// session is the session manager
session *session.Manager

// tunDown moves bytes down to the data channel.
tunDown chan []byte

// tunUp moves bytes up from the data channel.
tunUp chan []byte

// callback to be executed on shutdown.
whenDoneFn func()

// writeDeadline is used to set the write deadline.
writeDeadline tunDeadline
}

// newTUN creates a new TUN.
// This function TAKES OWNERSHIP of the conn.
func newTUN(logger model.Logger, conn networkio.FramingConn, session *session.Manager) *TUN {
return &TUN{
closeOnce: sync.Once{},
conn: conn,
hangup: make(chan any),
logger: logger,
network: conn.LocalAddr().Network(),
readBuffer: &bytes.Buffer{},
readDeadline: makeTUNDeadline(),
session: session,
tunDown: make(chan []byte),
tunUp: make(chan []byte, 10),
// this function is explicitely set empty so that we can safely use a callback even if not set.
whenDoneFn: func() {},
writeDeadline: makeTUNDeadline(),
}
}

// whenDone registers a callback to be called on shutdown.
// This is useful to propagate shutdown to workers.
func (t *TUN) whenDone(fn func()) {
t.whenDoneFn = fn
}

func (t *TUN) Close() error {
t.closeOnce.Do(func() {
close(t.hangup)
// We OWN the connection
t.conn.Close()
// execute any shutdown callback
t.whenDoneFn()
})
return nil
}

func (t *TUN) Read(data []byte) (int, error) {
for {
count, _ := t.readBuffer.Read(data)
if count > 0 {
// log.Printf("[tunbio] received %d bytes", len(data))
return count, nil
}
if isClosedChan(t.readDeadline.wait()) {
return 0, os.ErrDeadlineExceeded
}
select {
case extra := <-t.tunUp:
t.readBuffer.Write(extra)
case <-t.hangup:
return 0, net.ErrClosed
case <-t.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
}

func (t *TUN) Write(data []byte) (int, error) {
if isClosedChan(t.writeDeadline.wait()) {
return 0, os.ErrDeadlineExceeded
}
select {
case t.tunDown <- data:
return len(data), nil
case <-t.hangup:
return 0, net.ErrClosed
case <-t.writeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}

func (t *TUN) LocalAddr() net.Addr {
ip := t.session.TunnelInfo().IP
return &tunBioAddr{ip, t.network}
}

func (t *TUN) RemoteAddr() net.Addr {
gw := t.session.TunnelInfo().GW
return &tunBioAddr{gw, t.network}
}

func (t *TUN) SetDeadline(tm time.Time) error {
t.readDeadline.set(tm)
t.writeDeadline.set(tm)
return nil
}

func (t *TUN) SetReadDeadline(tm time.Time) error {
t.readDeadline.set(tm)
return nil
}

func (t *TUN) SetWriteDeadline(tm time.Time) error {
t.writeDeadline.set(tm)
return nil
}

// tunBioAddr is the type of address returned by [*TUN]
type tunBioAddr struct {
addr string
net string
}

var _ net.Addr = &tunBioAddr{}

// Network implements net.Addr. It returns the network
// for the underlying connection.
func (t *tunBioAddr) Network() string {
return t.net
}

// String implements net.Addr
func (t *tunBioAddr) String() string {
return t.addr
}
Loading

0 comments on commit b90d50a

Please sign in to comment.