From 78c0abe65918cb797d3c011d0dfa6de1beb17e5e Mon Sep 17 00:00:00 2001 From: Martin Hebnes Pedersen Date: Thu, 23 Nov 2023 20:40:54 +0100 Subject: [PATCH] Add connection prehook for scripted node traversal Implemented by spawning a user-defined process communicating with the remote station over stdio. Issue #114 --- conn_prehook.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++ connect.go | 12 ++++++ go.mod | 1 + go.sum | 2 + 4 files changed, 126 insertions(+) create mode 100644 conn_prehook.go diff --git a/conn_prehook.go b/conn_prehook.go new file mode 100644 index 00000000..3045cd3f --- /dev/null +++ b/conn_prehook.go @@ -0,0 +1,111 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "time" + + "golang.org/x/sync/errgroup" +) + +type prehookConn struct { + net.Conn + br *bufio.Reader + + executable string + args []string +} + +func NewPrehookConn(conn net.Conn, executable string, args ...string) prehookConn { + return prehookConn{ + Conn: conn, + br: bufio.NewReader(conn), + executable: executable, + args: args, + } +} + +func (p prehookConn) Read(b []byte) (int, error) { return p.br.Read(b) } + +// Wait waits for the prehook process to exit, returning nil if the process +// terminated successfully (exit code 0). +func (p prehookConn) Wait(ctx context.Context) error { + cmd := exec.CommandContext(ctx, p.executable, p.args...) + cmd.WaitDelay = time.Second + cmd.Stderr = os.Stderr + cmd.Stdout = p.Conn + cmdStdin, err := cmd.StdinPipe() + if err != nil { + return err + } + + // Copy environment to the child process. Also include additional + // relevant variables: REMOTE_ADDR, LOCAL_ADDR and the output of the + // env command. + cmd.Env = append(append(os.Environ(), + "PAT_REMOTE_ADDR="+p.RemoteAddr().String(), + "PAT_LOCAL_ADDR="+p.LocalAddr().String(), + ), envAll()...) + + if err := cmd.Start(); err != nil { + return err + } + + g, ctx := errgroup.WithContext(ctx) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + g.Go(func() error { return p.forwardLines(ctx, cmdStdin) }) + g.Go(func() error { defer cancel(); return cmd.Wait() }) + return g.Wait() +} + +// forwardLines forwards data from to the spawned process line by line. +// +// The line delimiter is CR or LF, but to facilitate scripting we append LF if +// it's missing. +// +// Wait one second after each line, to give the process time to terminate +// before delivering the next line. +func (p prehookConn) forwardLines(ctx context.Context, w io.Writer) error { + // Copy the lines to stdout so the user can see what's going on. + stdinBuffered := bufio.NewWriter(io.MultiWriter(w, os.Stdout)) + defer stdinBuffered.Flush() + + var isPrefix bool // true if we're in the middle of a line + for { + if !isPrefix { + // A line was just terminated (or no data has been read yet). + // Flush and wait one second to check if the process + // exited. If not we assume it expects an upcoming line. + if err := stdinBuffered.Flush(); err != nil { + return fmt.Errorf("child process exited prematurely: %w", err) + } + select { + case <-ctx.Done(): + return nil + case <-time.After(time.Second): + } + } + + b, err := p.br.ReadByte() + if err != nil { + return err + } + stdinBuffered.WriteByte(b) + isPrefix = !(b == '\n' || b == '\r') + + // Make sure CR is always followed by LF. It's easier to deal with in scripts. + if b == '\r' { + stdinBuffered.WriteByte('\n') + // Peek to check if the next byte is the LF we just wrote, in which case discard it. + if peek, _ := p.br.Peek(1); len(peek) > 0 && peek[0] == '\n' { + p.br.Discard(1) + } + } + } +} diff --git a/connect.go b/connect.go index df8614eb..8f1f7f86 100644 --- a/connect.go +++ b/connect.go @@ -202,6 +202,18 @@ func Connect(connectStr string) (success bool) { return } + if exec := url.Params.Get("prehook"); exec != "" { + log.Println("Running prehook...") + prehookConn := NewPrehookConn(conn, exec, url.Params["prehook-param"]...) + if err := prehookConn.Wait(ctx); err != nil { + conn.Close() + log.Printf("Prehook script failed: %s", err) + return + } + log.Println("Prehook succeeded") + conn = prehookConn + } + err = exchange(conn, url.Target, false) if err != nil { log.Printf("Exchange failed: %s", err) diff --git a/go.mod b/go.mod index b5238483..7fbce5a8 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/pd0mz/go-maidenhead v1.0.0 github.com/peterh/liner v1.2.1 github.com/spf13/pflag v1.0.5 + golang.org/x/sync v0.5.0 ) require ( diff --git a/go.sum b/go.sum index 849ab08d..c4aca25f 100644 --- a/go.sum +++ b/go.sum @@ -82,6 +82,8 @@ golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d h1:LO7XpTYMwTqxjLcGWPijK3vRXg1aWdlNOVOHRq45d7c= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210223212115-eede4237b368/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=