From 0cff6ea1f979ab5298c34b6449b7b19dc191ac13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Cie=C5=9Blak?= Date: Fri, 17 May 2024 16:28:28 +0200 Subject: [PATCH] Introduce ProcessManager to simplify VNet's public API --- lib/teleterm/vnet/service.go | 115 ++++++++++------------------ lib/vnet/setup.go | 134 +++++++++++++++++++++++++++++---- lib/vnet/setup_darwin.go | 38 +++------- lib/vnet/setup_other.go | 2 +- lib/vnet/vnet.go | 4 +- lib/vnet/vnet_test.go | 2 +- tool/tsh/common/vnet_darwin.go | 40 +++------- 7 files changed, 180 insertions(+), 155 deletions(-) diff --git a/lib/teleterm/vnet/service.go b/lib/teleterm/vnet/service.go index 06c1050c12f27..cc144a41282f7 100644 --- a/lib/teleterm/vnet/service.go +++ b/lib/teleterm/vnet/service.go @@ -24,7 +24,6 @@ import ( "sync" "github.com/gravitational/trace" - "golang.org/x/sync/errgroup" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -50,18 +49,10 @@ const ( type Service struct { api.UnimplementedVnetServiceServer - cfg Config - mu sync.Mutex - status status - // stopErrC is used to pass an error from goroutine that runs VNet in the background to the - // goroutine which handles RPC for stopping VNet. stopErrC gets closed after VNet stops. Starting - // VNet creates a new channel and assigns it as stopErrC. - // - // It's a buffered channel in case VNet crashes and there's no Stop RPC reading from stopErrC at - // that moment. - stopErrC chan error - // cancel stops the VNet instance running in a separate goroutine. - cancel context.CancelFunc + cfg Config + mu sync.Mutex + status status + processManager *vnet.ProcessManager } // New creates an instance of Service. @@ -106,62 +97,34 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR return &api.StartResponse{}, nil } - socket, socketPath, err := vnet.CreateSocket(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - longCtx, cancelLongCtx := context.WithCancel(context.Background()) - s.cancel = cancelLongCtx - defer func() { - // If by the end of this RPC the service is not running, make sure to cancel the long context. - if s.status != statusRunning { - cancelLongCtx() - } - }() - - g, longCtx := errgroup.WithContext(longCtx) - - g.Go(func() error { - <-longCtx.Done() - - return trace.Wrap(socket.Close()) - }) - - ipv6Prefix, err := vnet.IPv6Prefix() - if err != nil { - return nil, trace.Wrap(err) - } - dnsIPv6 := vnet.Ipv6WithSuffix(ipv6Prefix, []byte{2}) - - g.Go(func() error { - return trace.Wrap(vnet.ExecAdminSubcommand(longCtx, socketPath, ipv6Prefix.String(), dnsIPv6.String())) - }) - appProvider := &appProvider{ daemonService: s.cfg.DaemonService, clientStore: s.cfg.ClientStore, insecureSkipVerify: s.cfg.InsecureSkipVerify, } - ns, err := vnet.Setup(ctx, appProvider, socket, ipv6Prefix, dnsIPv6) + processManager, err := vnet.SetupAndRun(ctx, appProvider) if err != nil { return nil, trace.Wrap(err) } - - g.Go(func() error { - return trace.Wrap(ns.Run(longCtx)) - }) - - s.stopErrC = make(chan error, 1) + defer func() { + if s.status != statusRunning { + err := processManager.Close() + if err != nil && !errors.Is(err, context.Canceled) { + log.ErrorContext(ctx, "VNet closed with an error", "error", err) + } else { + log.DebugContext(ctx, "VNet closed") + } + } + }() go func() { - err := g.Wait() + err := processManager.Wait() if err != nil && !errors.Is(err, context.Canceled) { - log.ErrorContext(longCtx, "VNet closed with an error", "error", err) - s.stopErrC <- err + log.ErrorContext(ctx, "VNet closed with an error", "error", err) + } else { + log.DebugContext(ctx, "VNet closed") } - close(s.stopErrC) // TODO(ravicious): Notify the Electron app about change of VNet state, but only if it's // running. If it's not running, then the Start RPC has already failed and forwarded the error @@ -170,9 +133,12 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR s.mu.Lock() defer s.mu.Unlock() - s.status = statusNotRunning + if s.status == statusRunning { + s.status = statusNotRunning + } }() + s.processManager = processManager s.status = statusRunning return &api.StartResponse{}, nil } @@ -182,23 +148,12 @@ func (s *Service) Stop(ctx context.Context, req *api.StopRequest) (*api.StopResp s.mu.Lock() defer s.mu.Unlock() - errC := make(chan error) - - go func() { - errC <- trace.Wrap(s.stopLocked()) - }() - - select { - case <-ctx.Done(): - return nil, trace.Wrap(ctx.Err()) - case err := <-errC: - if err != nil { - return nil, trace.Wrap(err) - } - - return &api.StopResponse{}, nil + err := s.stopLocked() + if err != nil { + return nil, trace.Wrap(err) } + return &api.StopResponse{}, nil } func (s *Service) stopLocked() error { @@ -210,10 +165,13 @@ func (s *Service) stopLocked() error { return nil } - s.cancel() - s.status = statusNotRunning + err := s.processManager.Close() + if err != nil && !errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } - return trace.Wrap(<-s.stopErrC) + s.status = statusNotRunning + return nil } // Close stops VNet service and prevents it from being started again. Blocks until VNet stops. @@ -223,9 +181,12 @@ func (s *Service) Close() error { defer s.mu.Unlock() err := s.stopLocked() - s.status = statusClosed + if err != nil { + return trace.Wrap(err) + } - return trace.Wrap(err) + s.status = statusClosed + return nil } type appProvider struct { diff --git a/lib/vnet/setup.go b/lib/vnet/setup.go index b1fc37d16676e..0fe16b25b7eeb 100644 --- a/lib/vnet/setup.go +++ b/lib/vnet/setup.go @@ -19,39 +19,95 @@ package vnet import ( "context" "log/slog" - "net" "os" "time" "github.com/gravitational/trace" + "golang.org/x/sync/errgroup" "golang.zx2c4.com/wireguard/tun" - "gvisor.dev/gvisor/pkg/tcpip" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" ) -// CreateSocket creates a socket that's going to be used to receive the TUN device created by the -// admin subcommand. The admin subcommand quits when it detects that the socket has been closed. -func CreateSocket(ctx context.Context) (*net.UnixListener, string, error) { - socket, socketPath, err := createUnixSocket() +// SetupAndRun creates a network stack for VNet and runs it in the background. To do this, it also +// needs to launch an admin subcommand in the background. It returns [ProcessManager] which controls +// the lifecycle of both background tasks. +// +// The caller is expected to call Close on the process manager to close the network stack and clean +// up any resources used by it. +// +// ctx is used to wait for setup steps that happen before SetupAndRun hands out the control to the +// process manager. If ctx gets canceled during SetupAndRun, the process manager gets closed along +// with its background tasks. +func SetupAndRun(ctx context.Context, appProvider AppProvider) (*ProcessManager, error) { + ipv6Prefix, err := IPv6Prefix() if err != nil { - return nil, "", trace.Wrap(err) + return nil, trace.Wrap(err) } - slog.DebugContext(ctx, "Created unix socket for admin subcommand", "socket", socketPath) - return socket, socketPath, nil -} + dnsIPv6 := Ipv6WithSuffix(ipv6Prefix, []byte{2}) -// TODO: Add comment. -func Setup(ctx context.Context, appProvider AppProvider, socket *net.UnixListener, ipv6Prefix, dnsIPv6 tcpip.Address) (*NetworkStack, error) { - tun, err := receiveTUNDevice(ctx, socket) + pm := newProcessManager() + success := false + defer func() { + if !success { + // Closes the socket and background tasks. + pm.Close() + } + }() + + // Create the socket that's used to receive the TUN device from the admin subcommand. + socket, socketPath, err := createUnixSocket() if err != nil { return nil, trace.Wrap(err) } + slog.DebugContext(ctx, "Created unix socket for admin subcommand", "socket", socketPath) + pm.AddBackgroundTask(func(ctx context.Context) error { + <-ctx.Done() + return trace.Wrap(socket.Close()) + }) + + // A channel to capture an error when waiting for a TUN device to be created. + // + // To create a TUN device, VNet first needs to start the admin subcommand. When the subcommand + // starts, osascript shows a password prompt. If the user closes this prompt, execAdminSubcommand + // fails and the socket ends up being closed. To make sure that the user sees the error from + // osascript about prompt being closed instead of an error from receiveTUNDevice about reading + // from a closed socket, we send the error from osascript immediately through this channel, rather + // than depending on pm.Wait. + tunOrAdminSubcommandErrC := make(chan error, 2) + var tun tun.Device + + pm.AddBackgroundTask(func(ctx context.Context) error { + err := execAdminSubcommand(ctx, socketPath, ipv6Prefix.String(), dnsIPv6.String()) + // Pass the osascript error immediately, without having to wait on pm to propagate the error. + tunOrAdminSubcommandErrC <- trace.Wrap(err) + return trace.Wrap(err) + }) + + go func() { + tunDevice, err := receiveTUNDevice(socket) + tun = tunDevice + tunOrAdminSubcommandErrC <- err + }() + + select { + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case err := <-tunOrAdminSubcommandErrC: + if err != nil { + return nil, trace.Wrap(err) + } + + if tun == nil { + // If the execution ever gets there, it's because of a bug. + return nil, trace.Errorf("no TUN device created, execAdminSubcommand must have returned early with no error") + } + } appResolver := NewTCPAppResolver(appProvider) - ns, err := NewNetworkStack(&Config{ + ns, err := newNetworkStack(&Config{ TUNDevice: tun, IPv6Prefix: ipv6Prefix, DNSIPv6: dnsIPv6, @@ -61,7 +117,55 @@ func Setup(ctx context.Context, appProvider AppProvider, socket *net.UnixListene return nil, trace.Wrap(err) } - return ns, nil + pm.AddBackgroundTask(func(ctx context.Context) error { + return trace.Wrap(ns.Run(ctx)) + }) + + success = true + return pm, nil +} + +func newProcessManager() *ProcessManager { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + return &ProcessManager{ + g: g, + ctx: ctx, + cancel: cancel, + } +} + +// ProcessManager handles background tasks needed to run VNet. +// Its semantics are similar to an error group with context. +type ProcessManager struct { + g *errgroup.Group + ctx context.Context + cancel context.CancelFunc +} + +// AddBackgroundTask adds a function to the error group. The context passed to bgTaskFunc is a +// background context that gets canceled when Close is called or any added bgTaskFunc returns an +// error. +func (pm *ProcessManager) AddBackgroundTask(bgTaskFunc func(ctx context.Context) error) { + pm.g.Go(func() error { + return trace.Wrap(bgTaskFunc(pm.ctx)) + }) +} + +// Wait blocks and waits for the background tasks to finish, which typically happens when another +// goroutine calls Close on the process manager. +func (pm *ProcessManager) Wait() error { + return trace.Wrap(pm.g.Wait()) +} + +// Close stops any active background tasks by canceling the underlying context. It then returns the +// error from the error group. +func (pm *ProcessManager) Close() error { + go func() { + pm.cancel() + }() + return trace.Wrap(pm.g.Wait()) } // AdminSubcommand is the tsh subcommand that should run as root that will create and setup a TUN device and diff --git a/lib/vnet/setup_darwin.go b/lib/vnet/setup_darwin.go index 9efc1f43c20c9..0e9553cfc8029 100644 --- a/lib/vnet/setup_darwin.go +++ b/lib/vnet/setup_darwin.go @@ -38,8 +38,10 @@ import ( "github.com/gravitational/teleport" ) -func receiveTUNDevice(ctx context.Context, socket *net.UnixListener) (tun.Device, error) { - tunName, tunFd, err := recvTUNNameAndFd(ctx, socket) +// receiveTUNDevice is a blocking call which waits for the admin subcommand to pass over the socket +// the name and fd of the TUN device. +func receiveTUNDevice(socket *net.UnixListener) (tun.Device, error) { + tunName, tunFd, err := recvTUNNameAndFd(socket) if err != nil { return nil, trace.Wrap(err, "receiving TUN name and file descriptor") } @@ -48,7 +50,7 @@ func receiveTUNDevice(ctx context.Context, socket *net.UnixListener) (tun.Device return tunDevice, trace.Wrap(err, "creating TUN device from file descriptor") } -func ExecAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { +func execAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { executableName, err := os.Executable() if err != nil { return trace.Wrap(err, "getting executable path") @@ -153,32 +155,12 @@ func sendTUNNameAndFd(socketPath, tunName string, fd uintptr) error { // recvTUNNameAndFd receives the name of a TUN device and its open file descriptor over a unix socket, meant // for passing the TUN from the root process which must create it to the user process. -func recvTUNNameAndFd(ctx context.Context, socket *net.UnixListener) (string, uintptr, error) { - var conn *net.UnixConn - errC := make(chan error, 1) - - go func() { - connection, err := socket.AcceptUnix() - conn = connection - errC <- err - }() - - select { - case <-ctx.Done(): - return "", 0, trace.Wrap(ctx.Err()) - case err := <-errC: - if err != nil { - return "", 0, trace.Wrap(err, "accepting connection on unix socket") - } +func recvTUNNameAndFd(socket *net.UnixListener) (string, uintptr, error) { + conn, err := socket.AcceptUnix() + if err != nil { + return "", 0, trace.Wrap(err, "accepting connection on unix socket") } - - // Close the connection early to unblock reads if the context is canceled. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - <-ctx.Done() - conn.Close() - }() + defer conn.Close() msg := make([]byte, 128) oob := make([]byte, unix.CmsgSpace(4)) // Fd is 4 bytes diff --git a/lib/vnet/setup_other.go b/lib/vnet/setup_other.go index 5e60d37949e6f..86d851d0f8ad7 100644 --- a/lib/vnet/setup_other.go +++ b/lib/vnet/setup_other.go @@ -50,6 +50,6 @@ func configureOS(ctx context.Context, cfg *osConfig) error { return trace.Wrap(ErrVnetNotImplemented) } -func ExecAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { +func execAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { return trace.Wrap(ErrVnetNotImplemented) } diff --git a/lib/vnet/vnet.go b/lib/vnet/vnet.go index 9ef59cdaf5272..66e7b9ad2a799 100644 --- a/lib/vnet/vnet.go +++ b/lib/vnet/vnet.go @@ -213,10 +213,10 @@ func newState() state { } } -// NewNetworkStack creates a new VNet network stack with the given configuration and root context. +// newNetworkStack creates a new VNet network stack with the given configuration and root context. // It takes ownership of [cfg.TUNDevice] and will handle closing it before Run() returns. Call Run() // on the returned network stack to start the VNet. -func NewNetworkStack(cfg *Config) (*NetworkStack, error) { +func newNetworkStack(cfg *Config) (*NetworkStack, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index 01ee3778f68fa..2b9458b5e5a83 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -115,7 +115,7 @@ func newTestPack(t *testing.T, ctx context.Context, appProvider AppProvider) *te tcpHandlerResolver := NewTCPAppResolver(appProvider) // Create the VNet and connect it to the other side of the TUN. - ns, err := NewNetworkStack(&Config{ + ns, err := newNetworkStack(&Config{ TUNDevice: tun2, IPv6Prefix: vnetIPv6Prefix, DNSIPv6: dnsIPv6, diff --git a/tool/tsh/common/vnet_darwin.go b/tool/tsh/common/vnet_darwin.go index fee872498b319..c6798616d328e 100644 --- a/tool/tsh/common/vnet_darwin.go +++ b/tool/tsh/common/vnet_darwin.go @@ -25,7 +25,6 @@ import ( "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" - "golang.org/x/sync/errgroup" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/vnet" @@ -49,41 +48,20 @@ func (c *vnetCommand) run(cf *CLIConf) error { return trace.Wrap(err) } - ctx, cancel := context.WithCancel(cf.Context) - defer cancel() - g, ctx := errgroup.WithContext(ctx) - - socket, socketPath, err := vnet.CreateSocket(ctx) - if err != nil { - return trace.Wrap(err) - } - - g.Go(func() error { - <-ctx.Done() - - return trace.Wrap(socket.Close()) - }) - - ipv6Prefix, err := vnet.IPv6Prefix() - if err != nil { - return trace.Wrap(err) - } - dnsIPv6 := vnet.Ipv6WithSuffix(ipv6Prefix, []byte{2}) - - g.Go(func() error { - return trace.Wrap(vnet.ExecAdminSubcommand(ctx, socketPath, ipv6Prefix.String(), dnsIPv6.String())) - }) - - manager, err := vnet.Setup(ctx, appProvider, socket, ipv6Prefix, dnsIPv6) + processManager, err := vnet.SetupAndRun(cf.Context, appProvider) if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } return trace.Wrap(err) } - g.Go(func() error { - return trace.Wrap(manager.Run(ctx)) - }) + go func() { + <-cf.Context.Done() + _ = processManager.Close() + }() - err = g.Wait() + err = processManager.Wait() if errors.Is(err, context.Canceled) { return nil }