diff --git a/cmd/dagger/shell.go b/cmd/dagger/shell.go index 3068119344..b393b3ccd1 100644 --- a/cmd/dagger/shell.go +++ b/cmd/dagger/shell.go @@ -244,26 +244,11 @@ func litWord(s string) *syntax.Word { // run parses code and executes the interpreter's Runner func (h *shellCallHandler) run(ctx context.Context, reader io.Reader, name string) error { - file, err := syntax.NewParser(syntax.Variant(syntax.LangPOSIX)).Parse(reader, name) + file, err := parseShell(reader, name) if err != nil { return err } - syntax.Walk(file, func(node syntax.Node) bool { - if node, ok := node.(*syntax.CmdSubst); ok { - // Rewrite command substitutions from $(foo; bar) to $(exec <&-; foo; bar) - // so that all the original commands run with a closed (nil) standard input. - node.Stmts = append([]*syntax.Stmt{{ - Cmd: &syntax.CallExpr{Args: []*syntax.Word{litWord("..exec")}}, - Redirs: []*syntax.Redirect{{ - Op: syntax.DplIn, - Word: litWord("-"), - }}, - }}, node.Stmts...) - } - return true - }) - h.stdoutBuf.Reset() h.stderrBuf.Reset() @@ -308,6 +293,29 @@ func (h *shellCallHandler) run(ctx context.Context, reader io.Reader, name strin }) } +func parseShell(reader io.Reader, name string) (*syntax.File, error) { + file, err := syntax.NewParser(syntax.Variant(syntax.LangPOSIX)).Parse(reader, name) + if err != nil { + return nil, err + } + + syntax.Walk(file, func(node syntax.Node) bool { + if node, ok := node.(*syntax.CmdSubst); ok { + // Rewrite command substitutions from $(foo; bar) to $(exec <&-; foo; bar) + // so that all the original commands run with a closed (nil) standard input. + node.Stmts = append([]*syntax.Stmt{{ + Cmd: &syntax.CallExpr{Args: []*syntax.Word{litWord("..exec")}}, + Redirs: []*syntax.Redirect{{ + Op: syntax.DplIn, + Word: litWord("-"), + }}, + }}, node.Stmts...) + } + return true + }) + return file, nil +} + // runPath executes code from a file func (h *shellCallHandler) runPath(ctx context.Context, path string) error { f, err := os.Open(path) @@ -423,6 +431,7 @@ func (h *shellCallHandler) loadReadlineConfig(prompt string) (*readline.Config, Prompt: prompt, HistoryFile: filepath.Join(dataRoot, "histfile"), HistoryLimit: 1000, + AutoComplete: &shellAutoComplete{h}, }, nil } @@ -1243,16 +1252,14 @@ type ShellCommand struct { // Expected arguments Args PositionalArgs - // Run is the function that will be executed if it's the first command - // in the pipeline and RunState is not defined - Run func(cmd *ShellCommand, args []string) error + // Expected state + State StateArg - // RunState is the function for executing a command that can be chained - // in a pipeline - // - // If defined, it's always used, even if it's the first command in the - // pipeline. For commands that should only be the first, define `Run` instead. - RunState func(cmd *ShellCommand, args []string, st *ShellState) error + // Run is the function that will be executed. + Run func(cmd *ShellCommand, args []string, st *ShellState) error + + // Complete provides builtin completions + Complete func(ctx *CompletionContext, args []string) *CompletionContext // HelpFunc is a custom function for customizing the help output HelpFunc func(cmd *ShellCommand) string @@ -1376,10 +1383,26 @@ func NoArgs(args []string) error { return nil } +type StateArg uint + +const ( + AnyState StateArg = iota + RequiredState + NoState +) + // Execute is the main dispatcher function for shell builtin commands func (c *ShellCommand) Execute(ctx context.Context, h *shellCallHandler, args []string, st *ShellState) error { - if st != nil && c.RunState == nil { - return fmt.Errorf("command %q cannot be piped", c.Name()) + switch c.State { + case AnyState: + case RequiredState: + if st == nil { + return fmt.Errorf("command %q must be piped\nusage: %s", c.Name(), c.Use) + } + case NoState: + if st != nil { + return fmt.Errorf("command %q cannot be piped\nusage: %s", c.Name(), c.Use) + } } if c.Args != nil { if err := c.Args(args); err != nil { @@ -1406,10 +1429,7 @@ func (c *ShellCommand) Execute(ctx context.Context, h *shellCallHandler, args [] shellDebug(ctx, "└ CmdExec(%v)", a) } c.SetContext(ctx) - if c.RunState != nil { - return c.RunState(c, a, st) - } - return c.Run(c, a) + return c.Run(c, a, st) } // shellFunctionUseLine returns the usage line fine for a function @@ -1878,7 +1898,8 @@ func (h *shellCallHandler) registerCommands() { //nolint:gocyclo Use: ".debug", Hidden: true, Args: NoArgs, - Run: func(_ *ShellCommand, _ []string) error { + State: NoState, + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { // Toggles debug mode, which can be useful when in interactive mode h.debug = !h.debug return nil @@ -1888,7 +1909,8 @@ func (h *shellCallHandler) registerCommands() { //nolint:gocyclo Use: ".help [command]", Description: "Print this help message", Args: MaximumArgs(1), - Run: func(cmd *ShellCommand, args []string) error { + State: NoState, + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { if len(args) == 1 { c, err := h.BuiltinCommand(args[0]) if err != nil { @@ -1935,7 +1957,7 @@ Local module paths are resolved relative to the workdir on the host, not relativ to the currently loaded module. `, Args: MaximumArgs(1), - RunState: func(cmd *ShellCommand, args []string, st *ShellState) error { + Run: func(cmd *ShellCommand, args []string, st *ShellState) error { var err error ctx := cmd.Context() @@ -2049,7 +2071,8 @@ to the currently loaded module. `, GroupID: moduleGroup.ID, Args: ExactArgs(1), - Run: func(cmd *ShellCommand, args []string) error { + State: NoState, + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { st, err := h.getOrInitDefState(args[0], func() (*moduleDef, error) { return initializeModule(cmd.Context(), h.dag, args[0], true) }) @@ -2069,29 +2092,52 @@ to the currently loaded module. Description: "Dependencies from the module loaded in the current context", GroupID: moduleGroup.ID, Args: NoArgs, - Run: func(cmd *ShellCommand, _ []string) error { + State: NoState, + Run: func(cmd *ShellCommand, _ []string, _ *ShellState) error { _, err := h.GetModuleDef(nil) if err != nil { return err } return cmd.Send(h.newDepsState()) }, + Complete: func(ctx *CompletionContext, _ []string) *CompletionContext { + return &CompletionContext{ + Completer: ctx.Completer, + CmdRoot: shellDepsCmdName, + root: true, + } + }, }, &ShellCommand{ Use: shellStdlibCmdName, Description: "Standard library functions", Args: NoArgs, - Run: func(cmd *ShellCommand, _ []string) error { + State: NoState, + Run: func(cmd *ShellCommand, _ []string, _ *ShellState) error { return cmd.Send(h.newStdlibState()) }, + Complete: func(ctx *CompletionContext, _ []string) *CompletionContext { + return &CompletionContext{ + Completer: ctx.Completer, + CmdRoot: shellStdlibCmdName, + root: true, + } + }, }, &ShellCommand{ Use: ".core [function]", Description: "Load any core Dagger type", - Args: NoArgs, - Run: func(cmd *ShellCommand, args []string) error { + State: NoState, + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { return cmd.Send(h.newCoreState()) }, + Complete: func(ctx *CompletionContext, _ []string) *CompletionContext { + return &CompletionContext{ + Completer: ctx.Completer, + CmdRoot: shellCoreCmdName, + root: true, + } + }, }, cobraToShellCommand(loginCmd), cobraToShellCommand(logoutCmd), @@ -2125,10 +2171,11 @@ to the currently loaded module. &ShellCommand{ Use: shellFunctionUseLine(def, fn), Description: fn.Description, + State: NoState, HelpFunc: func(cmd *ShellCommand) string { return shellFunctionDoc(def, fn) }, - Run: func(cmd *ShellCommand, args []string) error { + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { ctx := cmd.Context() st := h.newState() @@ -2139,6 +2186,13 @@ to the currently loaded module. return cmd.Send(st) }, + Complete: func(ctx *CompletionContext, args []string) *CompletionContext { + return &CompletionContext{ + Completer: ctx.Completer, + ModFunction: fn, + root: true, + } + }, }, ) } @@ -2160,7 +2214,8 @@ func cobraToShellCommand(c *cobra.Command) *ShellCommand { Use: "." + c.Use, Description: c.Short, GroupID: c.GroupID, - Run: func(cmd *ShellCommand, args []string) error { + State: NoState, + Run: func(cmd *ShellCommand, args []string, _ *ShellState) error { // Re-execute the dagger command (hack) args = append([]string{cmd.CleanName()}, args...) ctx := cmd.Context() diff --git a/cmd/dagger/shell_completion.go b/cmd/dagger/shell_completion.go new file mode 100644 index 0000000000..8c2e9b66a2 --- /dev/null +++ b/cmd/dagger/shell_completion.go @@ -0,0 +1,369 @@ +package main + +import ( + "slices" + "strings" + + "github.com/chzyer/readline" + "mvdan.cc/sh/v3/syntax" +) + +// shellAutoComplete is a wrapper for the shell call handler +type shellAutoComplete struct { + // This is separated out, since we don't want to have to attach all these + // methods to the shellCallHandler directly + *shellCallHandler +} + +var _ readline.AutoCompleter = (*shellAutoComplete)(nil) + +func (h *shellAutoComplete) Do(line []rune, pos int) (newLine [][]rune, length int) { + file, err := parseShell(strings.NewReader(string(line)), "") + if err != nil { + return nil, 0 + } + + // find the smallest stmt next to the cursor - this allows accurate + // completion inside subshells, for example + var stmt *syntax.Stmt + excluded := map[*syntax.Stmt]struct{}{} + syntax.Walk(file, func(node syntax.Node) bool { + switch node := node.(type) { + case *syntax.BinaryCmd: + if node.Op == syntax.Pipe { + // pipes are special, and those statements aren't atomic + // because they're chained off of the previous ones - so avoid + // isolating them + excluded[node.X] = struct{}{} + excluded[node.Y] = struct{}{} + } + case *syntax.Stmt: + if stmt == nil { + stmt = node + break + } + if pos < int(node.Pos().Offset()) || pos > int(node.End().Offset()) { + return false + } + if _, ok := excluded[node]; !ok { + stmt = node + } + } + return true + }) + + var inprogressWord *syntax.Word + syntax.Walk(file, func(node syntax.Node) bool { + if node, ok := node.(*syntax.Word); ok { + if node.End().Offset() == uint(pos) { + inprogressWord = node + return false + } + } + return true + }) + var inprogressPrefix string + if inprogressWord != nil { + inprogressPrefix = inprogressWord.Lit() + } + + // discard the in-progress word for the process of determining the + // auto-completion context (since it's likely to be invalid) + var cursor uint + if inprogressWord == nil { + cursor = uint(pos) + } else { + cursor = inprogressWord.Pos().Offset() + } + + shctx := h.root() + if stmt != nil { + shctx = h.dispatch(shctx, stmt, cursor) + } + if shctx == nil { + return nil, 0 + } + + var results [][]rune + for _, result := range shctx.completions(inprogressPrefix) { + if result, ok := strings.CutPrefix(result, inprogressPrefix); ok { + results = append(results, []rune(result+" ")) + } + } + return results, len(inprogressPrefix) +} + +func (h *shellAutoComplete) dispatch(previous *CompletionContext, stmt *syntax.Stmt, cursor uint) *CompletionContext { + if stmt == nil { + return previous + } + switch cmd := stmt.Cmd.(type) { + case *syntax.CallExpr: + return h.dispatchCall(previous, cmd, cursor) + case *syntax.BinaryCmd: + return h.dispatchPipe(previous, cmd, cursor) + } + return nil +} + +func (h *shellAutoComplete) dispatchCall(previous *CompletionContext, call *syntax.CallExpr, cursor uint) *CompletionContext { + if call.Pos().Offset() >= cursor { + // short-circuit calls once we get past the current cursor context + return previous + } + + args := make([]string, 0, len(call.Args)) + for _, arg := range call.Args { + args = append(args, arg.Lit()) + } + return previous.lookupField(args[0], args[1:]) +} + +func (h *shellAutoComplete) dispatchPipe(previous *CompletionContext, pipe *syntax.BinaryCmd, cursor uint) *CompletionContext { + if pipe.Op != syntax.Pipe { + return nil + } + + previous = h.dispatch(previous, pipe.X, cursor) + if previous == nil { + return nil + } + + if pipe.OpPos.Offset() >= cursor { + // short-circuit pipes once we get past the current cursor context + return previous + } + previous = previous.lookupType() + if previous == nil { + return nil + } + + return h.dispatch(previous, pipe.Y, cursor) +} + +func (h *shellAutoComplete) root() *CompletionContext { + return &CompletionContext{ + Completer: h, + root: true, + } +} + +// CompletionContext provides completions for a specific point in a command +// chain. Each point is represented by one of `Mod` prefixed fields being set +// at a time. +type CompletionContext struct { + Completer *shellAutoComplete + + // CmdRoot is the name of a namespace-setting command. + CmdRoot string + + // ModType indicates the completions should be performed on an + // object/interface/etc. + ModType functionProvider + + // ModFunc indicates the completions should be performed on the arguments + // for a function call. + ModFunction *modFunction + + root bool +} + +func (ctx *CompletionContext) completions(prefix string) []string { + var results []string + switch { + case ctx.ModFunction != nil: + // TODO: also complete required args sometimes (depending on type) + + // complete optional args + if strings.HasPrefix(prefix, "-") { + for _, arg := range ctx.ModFunction.OptionalArgs() { + flag := "--" + arg.FlagName() + results = append(results, flag) + } + } + + case ctx.ModType != nil: + // complete possible functions for this type + for _, f := range ctx.ModType.GetFunctions() { + results = append(results, f.CmdName()) + } + // complete potentially chainable builtins + for _, builtin := range ctx.builtins() { + results = append(results, builtin.Name()) + } + + case ctx.root: + for _, cmd := range slices.Concat(ctx.builtins(), ctx.stdlib()) { + results = append(results, cmd.Name()) + } + if md, _ := ctx.Completer.GetModuleDef(nil); md != nil { + for _, fn := range md.MainObject.AsFunctionProvider().GetFunctions() { + results = append(results, fn.CmdName()) + } + for _, dep := range md.Dependencies { + results = append(results, dep.Name) + } + } + for modRef := range ctx.Completer.modDefs { + if modRef != "" { + results = append(results, modRef) + } + } + + case ctx.CmdRoot == shellStdlibCmdName: + for _, cmd := range ctx.Completer.Stdlib() { + results = append(results, cmd.Name()) + } + + case ctx.CmdRoot == shellDepsCmdName: + if md, _ := ctx.Completer.GetModuleDef(nil); md != nil { + for _, dep := range md.Dependencies { + results = append(results, dep.Name) + } + } + + case ctx.CmdRoot == shellCoreCmdName: + for _, fn := range ctx.Completer.modDef(nil).GetCoreFunctions() { + results = append(results, fn.CmdName()) + } + } + + return results +} + +func (ctx *CompletionContext) lookupField(field string, args []string) *CompletionContext { + if cmd := ctx.builtinCmd(field); cmd != nil { + return cmd.Complete(ctx, args) + } + + def := ctx.Completer.modDef(nil) + + if ctx.ModType != nil { + next, err := def.GetFunction(ctx.ModType, field) + if err != nil { + return nil + } + return &CompletionContext{ + Completer: ctx.Completer, + ModFunction: next, + } + } + + // Limit options for these namespace-setting commands + switch ctx.CmdRoot { + case shellStdlibCmdName: + if cmd := ctx.stdlibCmd(field); cmd != nil { + return cmd.Complete(ctx, args) + } + case shellDepsCmdName: + // TODO: loading other modules isn't supported yet + return nil + + case shellCoreCmdName: + if fn := def.GetCoreFunction(field); fn != nil { + return &CompletionContext{ + Completer: ctx.Completer, + ModFunction: fn, + } + } + } + + // Default lookup and fallbacks after this point, which only happens + // when it's the first command. + if !ctx.root { + return nil + } + + // 1. Current module function + if def.HasMainFunction(field) { + next, err := def.GetFunction(def.MainObject.AsFunctionProvider(), field) + if err != nil { + return nil + } + return &CompletionContext{ + Completer: ctx.Completer, + ModFunction: next, + } + } + + // 2. Dependency + if dep := def.GetDependency(field); dep != nil { + // TODO: loading other modules isn't supported yet + return nil + } + + // 3. Stdlib + if cmd := ctx.stdlibCmd(field); cmd != nil { + return cmd.Complete(ctx, args) + } + + // 4. Module reference + // TODO: loading other modules isn't supported yet + if field == ctx.Completer.modRef { + return &CompletionContext{ + Completer: ctx.Completer, + ModFunction: def.MainObject.AsObject.Constructor, + } + } + + return nil +} + +func (ctx *CompletionContext) lookupType() *CompletionContext { + if ctx.ModType != nil || ctx.CmdRoot != "" { + return ctx + } + if ctx.ModFunction != nil { + def := ctx.Completer.modDef(nil) + next := def.GetFunctionProvider(ctx.ModFunction.ReturnType.Name()) + return &CompletionContext{ + Completer: ctx.Completer, + ModType: next, + } + } + return nil +} + +func (ctx *CompletionContext) builtins() []*ShellCommand { + var cmds []*ShellCommand + for _, cmd := range ctx.Completer.Builtins() { + if ctx.root && cmd.State != RequiredState || !ctx.root && cmd.State != NoState { + cmds = append(cmds, cmd) + } + } + return cmds +} + +func (ctx *CompletionContext) stdlib() []*ShellCommand { + var cmds []*ShellCommand + for _, cmd := range ctx.Completer.Stdlib() { + if ctx.root && cmd.State != RequiredState || !ctx.root && cmd.State != NoState { + cmds = append(cmds, cmd) + } + } + return cmds +} + +func (ctx *CompletionContext) builtinCmd(name string) *ShellCommand { + for _, cmd := range ctx.builtins() { + if cmd.Name() == name { + if cmd.Complete == nil { + return nil + } + return cmd + } + } + return nil +} + +func (ctx *CompletionContext) stdlibCmd(name string) *ShellCommand { + for _, cmd := range ctx.stdlib() { + if cmd.Name() == name { + if cmd.Complete == nil { + return nil + } + return cmd + } + } + return nil +} diff --git a/cmd/dagger/shell_completion_test.go b/cmd/dagger/shell_completion_test.go new file mode 100644 index 0000000000..1652124414 --- /dev/null +++ b/cmd/dagger/shell_completion_test.go @@ -0,0 +1,132 @@ +package main + +import ( + "context" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "dagger.io/dagger" + "github.com/stretchr/testify/require" +) + +func TestShellAutocomplete(t *testing.T) { + // each cmdline is a prompt input + // the contents of the angle brackets are the word we want to complete - + // everything before the $ sign is already written, and one of the response + // options should include the contents after the $ + + cmdlines := []string{ + // top-level function + ``, + `<$container >`, + ` <$container >`, + ` "alpine:latest"`, + `| directory`, + + // top-level deps + ``, + + // stdlib fallback + ``, + `directory | `, + + // chaining + `container | `, + `container | directory "./path" | `, + // NOTE: this requires parsing partial parse trees + // "container | <$directory >", + + // subshells + // FIXME: this edge case should probably still work + // `container | with-directory $(<$container >)`, + `container | with-directory $()`, + `container | with-directory $(container | )`, + + // args + `container <--$packages >`, + `container <--$packages > | directory`, + `container | directory <--$expand >`, + + // .deps builtin + `<.dep$s >`, + `<$.deps >`, + `.deps | `, + + // .stdlib builtin + `<.std$lib >`, + `<$.stdlib >`, + `.stdlib | `, + `.stdlib | container <--$platform >`, + `.stdlib | container | `, + + // .core builtin + `<.co$re >`, + `<$.core >`, + `.core | `, + `.core | container <--$platform >`, + `.core | container | `, + + // FIXME: avoid inserting extra spaces + // ` `, + } + + wd, err := os.Getwd() + require.NoError(t, err) + + dir := t.TempDir() + require.NoError(t, os.CopyFS(dir, os.DirFS(filepath.Join(wd, "../../modules")))) + cmd := exec.Command("git", "init") + cmd.Dir = dir + require.NoError(t, cmd.Run()) + + os.Chdir(dir) + t.Cleanup(func() { + os.Chdir(wd) + }) + t.Setenv("DAGGER_MODULE", "./wolfi") + + ctx := context.TODO() + + client, err := dagger.Connect(ctx) + require.NoError(t, err) + t.Cleanup(func() { client.Close() }) + + handler := &shellCallHandler{ + dag: client, + stdin: nil, + stdout: io.Discard, + stderr: io.Discard, + debug: debug, + } + require.NoError(t, handler.RunAll(ctx, nil)) + autoComplete := shellAutoComplete{handler} + + for _, cmdline := range cmdlines { + t.Run(cmdline, func(t *testing.T) { + start := strings.IndexRune(cmdline, '<') + end := strings.IndexRune(cmdline, '>') + if start == -1 || end == -1 || !(start < end) { + require.FailNow(t, "invalid cmdline: could not find ") + } + inprogress, expected, ok := strings.Cut(cmdline[start+1:end], "$") + if !ok { + require.FailNow(t, "invalid cmdline: no token '$' in ") + } + + cmdline := cmdline[:start] + inprogress + cmdline[end+1:] + cursor := start + len(inprogress) + + results, length := autoComplete.Do([]rune(cmdline), cursor) + sresults := make([]string, 0, len(results)) + for _, result := range results { + sresults = append(sresults, string(result)) + } + require.Contains(t, sresults, expected) + require.Equal(t, len(inprogress), length) + }) + } +}