Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: share STDIN across different commands on pre-push hook #732

Merged
merged 7 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/lefthook/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ Run 'lefthook install' manually.`,
)

startTime := time.Now()
results := r.RunAll(ctx, sourceDirs)
results, runErr := r.RunAll(ctx, sourceDirs)
if runErr != nil {
return fmt.Errorf("failed to run the hook: %w", runErr)
}

if ctx.Err() != nil {
return errors.New("Interrupted")
Expand Down
49 changes: 49 additions & 0 deletions internal/lefthook/runner/cached_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package runner

import (
"bytes"
"io"
)

// cachedReader reads from the provided `io.Reader` until `io.EOF` and saves
// the read content into the inner buffer.
//
// After `io.EOF` it will be providing the read data again and again.
type cachedReader struct {
in io.Reader
useBuffer bool
buf []byte
reader *bytes.Reader
}

func NewCachedReader(in io.Reader) *cachedReader {
return &cachedReader{
in: in,
buf: []byte{},
reader: bytes.NewReader([]byte{}),
}
}

func (r *cachedReader) Read(p []byte) (int, error) {
if r.useBuffer {
n, err := r.reader.Read(p)
if err == io.EOF {
_, seekErr := r.reader.Seek(0, io.SeekStart)
if seekErr != nil {
panic(seekErr)
}

return n, err
}

return n, err
}

n, err := r.in.Read(p)
r.buf = append(r.buf, p[:n]...)
if err == io.EOF {
r.useBuffer = true
r.reader = bytes.NewReader(r.buf)
}
return n, err
}
24 changes: 24 additions & 0 deletions internal/lefthook/runner/cached_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package runner

import (
"bytes"
"io"
"testing"
)

func TestCachedReader(t *testing.T) {
testSlice := []byte("Some example string\nMultiline")

cachedReader := NewCachedReader(bytes.NewReader(testSlice))

for range 5 {
res, err := io.ReadAll(cachedReader)
if err != nil {
t.Errorf("unexpected err: %s", err)
}

if !bytes.Equal(res, testSlice) {
t.Errorf("expected %v to be equal to %v", res, testSlice)
}
}
}
9 changes: 3 additions & 6 deletions internal/lefthook/runner/exec/execute_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ type executeArgs struct {
interactive, useStdin bool
}

func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
var in io.Reader = nullReader{}
if opts.UseStdin {
in = os.Stdin
}
func (e CommandExecutor) Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error {
if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) {
tty, err := os.Open("/dev/tty")
if err == nil {
Expand Down Expand Up @@ -72,9 +68,10 @@ func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Write
return nil
}

func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error {
cmd := exec.CommandContext(ctx, command[0], command[1:]...)

cmd.Stdin = in
cmd.Stdout = out
cmd.Stderr = os.Stderr

Expand Down
11 changes: 4 additions & 7 deletions internal/lefthook/runner/exec/execute_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ type executeArgs struct {
root string
}

func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Writer) error {
var in io.Reader = nullReader{}
if opts.UseStdin {
in = os.Stdin
}
func (e CommandExecutor) Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error {
if opts.Interactive && !isatty.IsTerminal(os.Stdin.Fd()) {
tty, err := tty.Open()
if err == nil {
Expand Down Expand Up @@ -63,9 +59,10 @@ func (e CommandExecutor) Execute(ctx context.Context, opts Options, out io.Write
return nil
}

func (e CommandExecutor) RawExecute(ctx context.Context, command []string, out io.Writer) error {
cmd := exec.Command(command[0], command[1:]...)
func (e CommandExecutor) RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error {
cmd := exec.CommandContext(ctx, command[0], command[1:]...)

cmd.Stdin = in
cmd.Stdout = out
cmd.Stderr = os.Stderr

Expand Down
4 changes: 2 additions & 2 deletions internal/lefthook/runner/exec/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ type Options struct {
// Executor provides an interface for command execution.
// It is used here for testing purpose mostly.
type Executor interface {
Execute(ctx context.Context, opts Options, out io.Writer) error
RawExecute(ctx context.Context, command []string, out io.Writer) error
Execute(ctx context.Context, opts Options, in io.Reader, out io.Writer) error
RawExecute(ctx context.Context, command []string, in io.Reader, out io.Writer) error
}
10 changes: 0 additions & 10 deletions internal/lefthook/runner/exec/nullReader.go

This file was deleted.

15 changes: 15 additions & 0 deletions internal/lefthook/runner/null_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package runner

import "io"

// nullReader always returns `io.EOF`.
type nullReader struct{}

func NewNullReader() io.Reader {
return nullReader{}
}

// Implements io.Reader interface.
func (nullReader) Read(b []byte) (int, error) {
return 0, io.EOF
}
20 changes: 20 additions & 0 deletions internal/lefthook/runner/null_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package runner

import (
"bytes"
"io"
"testing"
)

func TestNullReader(t *testing.T) {
nullReader := NewNullReader()

res, err := io.ReadAll(nullReader)
if err != nil {
t.Errorf("unexpected err: %s", err)
}

if !bytes.Equal(res, []byte{}) {
t.Errorf("expected %v to be equal to %v", res, []byte{})
}
}
28 changes: 21 additions & 7 deletions internal/lefthook/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,19 @@ type Options struct {
type Runner struct {
Options

stdin io.Reader
partiallyStagedFiles []string
failed atomic.Bool
executor exec.Executor
}

func New(opts Options) *Runner {
return &Runner{
Options: opts,
Options: opts,

// Some hooks use STDIN for parsing data from Git. To allow multiple commands
// and scripts access the same Git data STDIN is cached via cachedReader.
stdin: NewCachedReader(os.Stdin),
executor: exec.CommandExecutor{},
}
}
Expand All @@ -79,16 +84,16 @@ type executable interface {

// RunAll runs scripts and commands.
// LFS hook is executed at first if needed.
func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) []Result {
func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) ([]Result, error) {
results := make([]Result, 0, len(r.Hook.Commands)+len(r.Hook.Scripts))

if err := r.runLFSHook(ctx); err != nil {
log.Error(err)
return results, err
}

if r.Hook.DoSkip(r.Repo.State()) {
r.logSkip(r.HookName, "hook setting")
return results
return results, nil
}

if !r.DisableTTY && !r.Hook.Follow {
Expand All @@ -113,7 +118,7 @@ func (r *Runner) RunAll(ctx context.Context, sourceDirs []string) []Result {

r.postHook()

return results
return results, nil
}

func (r *Runner) runLFSHook(ctx context.Context) error {
Expand Down Expand Up @@ -144,6 +149,7 @@ func (r *Runner) runLFSHook(ctx context.Context) error {
[]string{"git", "lfs", r.HookName},
r.GitArgs...,
),
r.stdin,
out,
)

Expand Down Expand Up @@ -490,6 +496,12 @@ func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool {
log.SetName(opts.Name)
defer log.UnsetName(opts.Name)

// If the command does not explicitly `use_stdin` no input will be provided.
var in io.Reader = NewNullReader()
if opts.UseStdin {
in = r.stdin
}

if (follow || opts.Interactive) && r.LogSettings.LogExecution() {
r.logExecute(opts.Name, nil, nil)

Expand All @@ -500,12 +512,14 @@ func (r *Runner) run(ctx context.Context, opts exec.Options, follow bool) bool {
out = io.Discard
}

err := r.executor.Execute(ctx, opts, out)
err := r.executor.Execute(ctx, opts, in, out)

return err == nil
}

out := bytes.NewBuffer(make([]byte, 0))
err := r.executor.Execute(ctx, opts, out)

err := r.executor.Execute(ctx, opts, in, out)

r.logExecute(opts.Name, err, out)

Expand Down
9 changes: 6 additions & 3 deletions internal/lefthook/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

type TestExecutor struct{}

func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.Writer) (err error) {
func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _in io.Reader, _out io.Writer) (err error) {
if strings.HasPrefix(opts.Commands[0], "success") {
err = nil
} else {
Expand All @@ -31,7 +31,7 @@ func (e TestExecutor) Execute(_ctx context.Context, opts exec.Options, _out io.W
return
}

func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _out io.Writer) error {
func (e TestExecutor) RawExecute(_ctx context.Context, _command []string, _in io.Reader, _out io.Writer) error {
return nil
}

Expand Down Expand Up @@ -766,7 +766,10 @@ func TestRunAll(t *testing.T) {
}

t.Run(fmt.Sprintf("%d: %s", i, tt.name), func(t *testing.T) {
results := runner.RunAll(context.Background(), tt.sourceDirs)
results, err := runner.RunAll(context.Background(), tt.sourceDirs)
if err != nil {
t.Errorf("unexpected error %s", err)
}

var success, fail []Result
for _, result := range results {
Expand Down
Loading