Skip to content

Commit

Permalink
fix: command execution control fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nxtcoder17 committed Jan 26, 2025
1 parent f4e8bb1 commit ed882c7
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 48 deletions.
171 changes: 125 additions & 46 deletions pkg/executor/cmd-executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import (
"syscall"
)

// type CommandGroup struct {
// Commands []func(context.Context) *exec.Cmd
// Parallel bool
// Sequential bool
// }

type CmdExecutor struct {
logger *slog.Logger
parentCtx context.Context
Expand All @@ -19,12 +25,20 @@ type CmdExecutor struct {
mu sync.Mutex

kill func() error

Parallel []ParallelCommands
}

type ParallelCommands struct {
Index int
Len int
}

type CmdExecutorArgs struct {
Logger *slog.Logger
Commands []func(context.Context) *exec.Cmd
Interactive bool
Parallel []ParallelCommands
}

func NewCmdExecutor(ctx context.Context, args CmdExecutorArgs) *CmdExecutor {
Expand All @@ -38,6 +52,7 @@ func NewCmdExecutor(ctx context.Context, args CmdExecutorArgs) *CmdExecutor {
commands: args.Commands,
mu: sync.Mutex{},
interactive: args.Interactive,
Parallel: args.Parallel,
}
}

Expand Down Expand Up @@ -67,72 +82,136 @@ func killPID(pid int, logger ...*slog.Logger) error {
return nil
}

// Start implements Executor.
func (ex *CmdExecutor) Start() error {
ex.mu.Lock()
defer ex.mu.Unlock()
for i := range ex.commands {
if err := ex.parentCtx.Err(); err != nil {
return err
}
func (ex *CmdExecutor) exec(newCmd func(context.Context) *exec.Cmd) error {
if err := ex.parentCtx.Err(); err != nil {
return err
}

ctx, cf := context.WithCancel(ex.parentCtx)
defer cf()
ctx, cf := context.WithCancel(ex.parentCtx)
defer cf()

cmd := ex.commands[i](ctx)
cmd := newCmd(ctx)

cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
if ex.interactive {
cmd.Stdin = os.Stdin
cmd.SysProcAttr.Foreground = true
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
if ex.interactive {
cmd.Stdin = os.Stdin
cmd.SysProcAttr.Foreground = true
}

if err := cmd.Start(); err != nil {
return err
}
if err := cmd.Start(); err != nil {
return err
}

logger := ex.logger.With("pid", cmd.Process.Pid, "command", i+1)
logger := ex.logger.With("pid", cmd.Process.Pid, "command", cmd.String())

ex.kill = func() error {
return killPID(cmd.Process.Pid, logger)
}
ex.kill = func() error {
return killPID(cmd.Process.Pid, logger)
}

exitErr := make(chan error, 1)

go func() {
if err := cmd.Wait(); err != nil {
logger.Debug("process finished (wait completed), got", "err", err)
go func() {
if err := cmd.Wait(); err != nil {
exitErr <- err
logger.Debug("process finished (wait completed), got", "err", err)
}
cf()
}()

select {
case <-ctx.Done():
logger.Debug("process finished (context cancelled)")
case err := <-exitErr:
if exitErr, ok := err.(*exec.ExitError); ok {
logger.Debug("process finished", "exit.code", exitErr.ExitCode())
if exitErr.ExitCode() != 0 {
return err
}
cf()
}()

select {
case <-ctx.Done():
logger.Debug("process finished (context cancelled)")
case <-ex.parentCtx.Done():
logger.Debug("process finished (parent context cancelled)")
}
case <-ex.parentCtx.Done():
logger.Debug("process finished (parent context cancelled)")
}

if ex.interactive {
// Send SIGTERM to the interactive process, as user will see it on his screen
proc, err := os.FindProcess(os.Getpid())
if err != nil {
return err
}

if ex.interactive {
// Send SIGTERM to the interactive process, as user will see it on his screen
proc, err := os.FindProcess(os.Getpid())
if err != nil {
err = proc.Signal(syscall.SIGTERM)
if err != nil {
if err != syscall.ESRCH {
logger.Error("failed to kill, got", "err", err)
return err
}
return err
}
}

if err := ex.kill(); err != nil {
return err
}

err = proc.Signal(syscall.SIGTERM)
if err != nil {
if err != syscall.ESRCH {
logger.Error("failed to kill, got", "err", err)
return err
logger.Debug("command fully executed and processed")
return nil
}

// Start implements Executor.
func (ex *CmdExecutor) Start() error {
ex.mu.Lock()
defer ex.mu.Unlock()

var wg sync.WaitGroup

for i := 0; i < len(ex.commands); i++ {
newCmd := ex.commands[i]

ex.logger.Info("HELLO", "idx", i, "ex.parallel", ex.Parallel)
isParallel := false

for _, p := range ex.Parallel {
if p.Index == i {
isParallel = true
for k := i; k <= i+p.Len; k++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := ex.exec(newCmd); err != nil {
ex.logger.Info("executing, got", "err", err)
// handle error
}
}()
}
return err

i = i + p.Len - 1
}
break
}

if err := ex.kill(); err != nil {
if isParallel {
continue
}

// if ex.Parallel {
// wg.Add(1)
// go func() {
// defer wg.Add(1)
// if err := ex.exec(newCmd); err != nil {
// // handle error
// }
// }()
// continue
// }

if err := ex.exec(newCmd); err != nil {
ex.logger.Error("cmd failed with", "err", err)
return err
}
}

logger.Debug("command fully executed and processed")
if len(ex.Parallel) > 0 {
wg.Wait()
}

return nil
Expand Down
9 changes: 7 additions & 2 deletions pkg/watcher/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ func NewWatcher(ctx context.Context, args WatcherArgs) (*Watcher, error) {
}

excludeDirs := map[string]struct{}{}

for _, dir := range args.IgnoreDirs {
if args.ShouldLogWatchEvents {
args.Logger.Debug("EXCLUDED from watching", "dir", dir)
Expand All @@ -282,8 +283,12 @@ func NewWatcher(ctx context.Context, args WatcherArgs) (*Watcher, error) {
}

for _, dir := range args.WatchDirs {
if strings.HasPrefix(dir, "-") {
excludeDirs[dir[1:]] = struct{}{}
if args.ShouldLogWatchEvents {
args.Logger.Debug("watch-dirs", "dir", dir)
}
d := filepath.Base(dir)
if strings.HasPrefix(d, "-") {
excludeDirs[d[1:]] = struct{}{}
}
}

Expand Down

0 comments on commit ed882c7

Please sign in to comment.