Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: exec with context #207

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script
import (
"bufio"
"container/ring"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
Expand Down Expand Up @@ -32,6 +33,7 @@ type Pipe struct {
stdout io.Writer
httpClient *http.Client

ctx context.Context
mu *sync.Mutex
err error
stderr io.Writer
Expand Down Expand Up @@ -166,6 +168,7 @@ func NewPipe() *Pipe {
return &Pipe{
Reader: ReadAutoCloser{},
mu: new(sync.Mutex),
ctx: context.Background(),
stdout: os.Stdout,
httpClient: http.DefaultClient,
env: nil,
Expand Down Expand Up @@ -423,7 +426,7 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
if err != nil {
return err
}
cmd := exec.Command(args[0], args[1:]...)
cmd := exec.CommandContext(p.ctx, args[0], args[1:]...)
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
Expand Down Expand Up @@ -470,7 +473,7 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
if err != nil {
return err
}
cmd := exec.Command(args[0], args[1:]...)
cmd := exec.CommandContext(p.ctx, args[0], args[1:]...)
cmd.Stdout = w
cmd.Stderr = w
pipeStderr := p.stdErr()
Expand Down Expand Up @@ -974,6 +977,13 @@ func (p *Pipe) WithStdout(w io.Writer) *Pipe {
return p
}

// WithContext sets context.Context for the pipe. Adds support for graceful pipe
// shutdown. Currently works with [Pipe.Exec] and [Pipe.ExecForEach]
func (p *Pipe) WithContext(ctx context.Context) *Pipe {
p.ctx = ctx
return p
}
sk91 marked this conversation as resolved.
Show resolved Hide resolved

// WriteFile writes the pipe's contents to the file path, truncating it if it
// exists, and returns the number of bytes successfully written, or an error.
func (p *Pipe) WriteFile(path string) (int64, error) {
Expand Down
44 changes: 44 additions & 0 deletions script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package script_test
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
Expand All @@ -15,6 +16,7 @@ import (
"strings"
"testing"
"testing/iotest"
"time"

"github.com/bitfield/script"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -1219,6 +1221,48 @@ func TestExecRunsGoHelpAndGetsUsageMessage(t *testing.T) {
}
}

func TestWithContextTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
p := script.NewPipe().WithContext(ctx).Exec("go run ./testdata/test_cli.go sleep 10")
p.Wait()
err := p.Error()
if err != nil && err.Error() != "signal: killed" {
t.Fatalf("context should timeout, %v", err)
}
t.Log(p.ExitStatus())
}

func TestWithContextTimeoutBeforeRun(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 0*time.Second)
defer cancel()
p := script.NewPipe().WithContext(ctx).Exec("go run ./testdata/test_cli.go sleep 10")
p.Wait()
err := p.Error()
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("context should timeout")
}
t.Log(p.ExitStatus())
}

func TestWithContextCancel(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
p := script.NewPipe().WithContext(ctx).Exec("go run ./testdata/test_cli.go sleep 2")
go func() {
<-time.After(1 * time.Second)
cancel()
}()
p.Wait()
err := p.Error()
if errors.Is(err, context.Canceled) {
t.Fatalf("context should cancel")
}
t.Log(p.ExitStatus())
}

func TestFileOutputsContentsOfSpecifiedFile(t *testing.T) {
t.Parallel()
want := "This is the first line in the file.\nHello, world.\nThis is another line in the file.\n"
Expand Down
67 changes: 67 additions & 0 deletions testdata/test_cli.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
)

const commandSleep = "sleep"

func main() {
if len(os.Args) < 2 {
fmt.Println("No command provided")
usage()
os.Exit(1)
}

handleSignals()

switch strings.ToLower(os.Args[1]) {
case commandSleep:
if len(os.Args) != 3 {
fmt.Printf("Usage: %s %s <seconds>\n", os.Args[0], commandSleep)
usage()
os.Exit(1)
}
err := sleep(os.Args[2])
if err != nil {
fmt.Printf("Error sleeping: %s\n", err)
os.Exit(1)
}
default:
fmt.Printf("Unknown command: %s\n", os.Args[1])
usage()
os.Exit(1)
}
}

func usage() {
fmt.Printf("Usage of %s:\n", os.Args[0])
fmt.Printf(" sleep <seconds>\n")
}

func sleep(seconds string) error {
s, err := strconv.Atoi(seconds)
if err != nil {
return fmt.Errorf("sleep expects an integer, got %s", seconds)
}
<-time.After(time.Duration(s) * time.Second)
return nil
}

func handleSignals() {
sigs := make(chan os.Signal, 1)

signal.Notify(sigs, syscall.SIGINT)

go func() {
<-sigs
fmt.Println("\nReceived Ctrl+C, exiting...")
os.Exit(0)
}()
}